
Commit 69095dd

Adding MViT v2 architecture (#6105)

* Adding mvitv2 architecture
* Fixing memory issues on tests and minor refactorings
* Adding input validation
* Adding docs and minor refactoring
* Add `min_temporal_size` to the supported meta-data
* Switch Tuple[int, int, int] to List[int] to more easily support the 2D case
* Adding more docs and references
* Change naming conventions of classes to follow the same pattern as MobileNetV3
* Fix test breakage
* Update todos
* Performance optimizations
1 parent d4a03fc

9 files changed: +631 −0 lines

docs/source/models.rst (1 addition, 0 deletions)

@@ -459,6 +459,7 @@ pre-trained weights:
 .. toctree::
    :maxdepth: 1

+   models/video_mvitv2
    models/video_resnet

 |

docs/source/models/video_mvitv2.rst (new file, 28 additions)

@@ -0,0 +1,28 @@
+Video MViTV2
+============
+
+.. currentmodule:: torchvision.models.video
+
+The MViT V2 model is based on the
+`MViTv2: Improved Multiscale Vision Transformers for Classification and Detection
+<https://arxiv.org/abs/2112.01526>`__ and `Multiscale Vision Transformers
+<https://arxiv.org/abs/2104.11227>`__ papers.
+
+
+Model builders
+--------------
+
+The following model builders can be used to instantiate an MViTV2 model, with or
+without pre-trained weights. All the model builders internally rely on the
+``torchvision.models.video.MViTV2`` base class. Please refer to the `source
+code
+<https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvitv2.py>`_ for
+more details about this class.
+
+.. autosummary::
+   :toctree: generated/
+   :template: function.rst
+
+   mvit_v2_t
+   mvit_v2_s
+   mvit_v2_b
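Since the new page documents the three builders, a minimal usage sketch may help. It assumes this commit is installed and that, as with torchvision's other video models, the classification head defaults to the 400 Kinetics-400 classes; the clip length matches the test input shapes below:

import torch
from torchvision.models.video import mvit_v2_s

# Build the small variant without pre-trained weights.
model = mvit_v2_s()
model.eval()

# (batch, channels, frames, height, width); 16 frames matches the
# input_shape used for mvit_v2_s in test/test_models.py below.
clip = torch.randn(1, 3, 16, 224, 224)
with torch.no_grad():
    logits = model(clip)
print(logits.shape)  # assumption: (1, 400) with the default Kinetics-400 head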
3 binary files added (939 Bytes each; binary files not shown)

test/test_extended_models.py (1 addition, 0 deletions)

@@ -87,6 +87,7 @@ def test_schema_meta_validation(model_fn):
         "license",
         "_metrics",
         "min_size",
+        "min_temporal_size",
         "num_params",
         "recipe",
         "unquantized",

test/test_models.py (12 additions, 0 deletions)

@@ -309,6 +309,15 @@ def _check_input_backprop(model, inputs):
         "image_size": 56,
         "input_shape": (1, 3, 56, 56),
     },
+    "mvit_v2_t": {
+        "input_shape": (1, 3, 16, 224, 224),
+    },
+    "mvit_v2_s": {
+        "input_shape": (1, 3, 16, 224, 224),
+    },
+    "mvit_v2_b": {
+        "input_shape": (1, 3, 32, 224, 224),
+    },
 }
 # speeding up slow models:
 slow_models = [

@@ -338,6 +347,7 @@ def _check_input_backprop(model, inputs):
 skipped_big_models = {
     "vit_h_14",
     "regnet_y_128gf",
+    "mvit_v2_b",
 }

 # The following contains configuration and expected values to be used tests that are model specific

@@ -830,6 +840,8 @@ def test_video_model(model_fn, dev):
         "num_classes": 50,
     }
     model_name = model_fn.__name__
+    if dev == "cuda" and SKIP_BIG_MODEL and model_name in skipped_big_models:
+        pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
     kwargs = {**defaults, **_model_params.get(model_name, {})}
     num_classes = kwargs.get("num_classes")
     input_shape = kwargs.pop("input_shape")
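The skip gate added in the last hunk combines three conditions: the target device, an environment flag, and membership in the big-model set. A minimal sketch of the pattern; the exact definition of SKIP_BIG_MODEL is outside this diff, so the default shown here is an assumption inferred from the skip message:

import os

import pytest

# Assumption: the flag defaults to on, consistent with the skip message
# telling users to set SKIP_BIG_MODEL=0 to re-enable the test.
SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1"
skipped_big_models = {"vit_h_14", "regnet_y_128gf", "mvit_v2_b"}


def maybe_skip_big_model(model_name, dev):
    # Skip memory-hungry models on CUDA unless explicitly re-enabled.
    if dev == "cuda" and SKIP_BIG_MODEL and model_name in skipped_big_models:
        pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")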

torchvision/models/video/__init__.py (1 addition, 0 deletions)

@@ -1 +1,2 @@
+from .mvitv2 import *
 from .resnet import *
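The wildcard import re-exports the new builders at the package level. A quick check, assuming mvitv2.py lists the builders in its __all__ (the usual torchvision pattern):

from torchvision.models import video

# The wildcard import in __init__.py makes the builders available here.
assert hasattr(video, "mvit_v2_t")
assert hasattr(video, "mvit_v2_s")
assert hasattr(video, "mvit_v2_b")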
