Skip to content

Commit 683d520

Browse files
committed
Change naming conventions of classes to follow the same pattern as MobileNetV3
1 parent 1330d9c commit 683d520

File tree

6 files changed

+41
-33
lines changed

6 files changed

+41
-33
lines changed

docs/source/models/video_mvitv2.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Video ResNet
33

44
.. currentmodule:: torchvision.models.video
55

6-
The MViTv2 model is based on the
6+
The MViT V2 model is based on the
77
`MViTv2: Improved Multiscale Vision Transformers for Classification and Detection
88
<https://arxiv.org/abs/2112.01526>`__ and `Multiscale Vision Transformers
99
<https://arxiv.org/abs/2104.11227>`__ papers.
@@ -23,6 +23,6 @@ more details about this class.
2323
:toctree: generated/
2424
:template: function.rst
2525

26-
mvitv2_t
27-
mvitv2_s
28-
mvitv2_b
26+
mvit_v2_t
27+
mvit_v2_s
28+
mvit_v2_b

test/test_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,13 +309,13 @@ def _check_input_backprop(model, inputs):
309309
"image_size": 56,
310310
"input_shape": (1, 3, 56, 56),
311311
},
312-
"mvitv2_t": {
312+
"mvit_v2_t": {
313313
"input_shape": (1, 3, 16, 224, 224),
314314
},
315-
"mvitv2_s": {
315+
"mvit_v2_s": {
316316
"input_shape": (1, 3, 16, 224, 224),
317317
},
318-
"mvitv2_b": {
318+
"mvit_v2_b": {
319319
"input_shape": (1, 3, 32, 224, 224),
320320
},
321321
}

torchvision/models/video/mvitv2.py

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,15 @@
1212
from .._utils import _ovewrite_named_param
1313

1414

15-
__all__ = ["MViTv2", "MViTV2_T_Weights", "MViTV2_S_Weights", "MViTV2_B_Weights", "mvitv2_t", "mvitv2_s", "mvitv2_b"]
15+
__all__ = [
16+
"MViTV2",
17+
"MViT_V2_T_Weights",
18+
"MViT_V2_S_Weights",
19+
"MViT_V2_B_Weights",
20+
"mvit_v2_t",
21+
"mvit_v2_s",
22+
"mvit_v2_b",
23+
]
1624

1725

1826
# TODO: add weights
@@ -264,7 +272,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
264272
return torch.cat((class_token, x), dim=1) + pos_embedding
265273

266274

267-
class MViTv2(nn.Module):
275+
class MViTV2(nn.Module):
268276
def __init__(
269277
self,
270278
spatial_size: Tuple[int, int],
@@ -283,7 +291,7 @@ def __init__(
283291
norm_layer: Optional[Callable[..., nn.Module]] = None,
284292
) -> None:
285293
"""
286-
MViTv2 main class.
294+
MViT V2 main class.
287295
288296
Args:
289297
spatial_size (tuple of ints): The spatial size of the input as ``(H, W)``.
@@ -424,7 +432,7 @@ def _mvitv2(
424432
weights: Optional[WeightsEnum],
425433
progress: bool,
426434
**kwargs: Any,
427-
) -> MViTv2:
435+
) -> MViTV2:
428436
if weights is not None:
429437
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
430438
assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
@@ -433,7 +441,7 @@ def _mvitv2(
433441
spatial_size = kwargs.pop("spatial_size", (224, 224))
434442
temporal_size = kwargs.pop("temporal_size", 16)
435443

436-
model = MViTv2(
444+
model = MViTV2(
437445
spatial_size=spatial_size,
438446
temporal_size=temporal_size,
439447
embed_channels=embed_channels,
@@ -452,29 +460,29 @@ def _mvitv2(
452460
return model
453461

454462

455-
class MViTV2_T_Weights(WeightsEnum):
463+
class MViT_V2_T_Weights(WeightsEnum):
456464
pass
457465

458466

459-
class MViTV2_S_Weights(WeightsEnum):
467+
class MViT_V2_S_Weights(WeightsEnum):
460468
pass
461469

462470

463-
class MViTV2_B_Weights(WeightsEnum):
471+
class MViT_V2_B_Weights(WeightsEnum):
464472
pass
465473

466474

467-
def mvitv2_t(*, weights: Optional[MViTV2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTv2:
475+
def mvit_v2_t(*, weights: Optional[MViT_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2:
468476
"""
469-
Constructs a tiny MViTv2 architecture from
477+
Constructs a tiny MViTV2 architecture from
470478
`MViTv2: Improved Multiscale Vision Transformers for Classification and Detection
471479
<https://arxiv.org/abs/2112.01526>`__ and `Multiscale Vision Transformers
472480
<https://arxiv.org/abs/2104.11227>`__.
473481
474482
Args:
475-
weights (:class:`~torchvision.models.video.MViTV2_T_Weights`, optional): The
483+
weights (:class:`~torchvision.models.video.MViT_V2_T_Weights`, optional): The
476484
pretrained weights to use. See
477-
:class:`~torchvision.models.video.MViTV2_T_Weights` below for
485+
:class:`~torchvision.models.video.MViT_V2_T_Weights` below for
478486
more details, and possible values. By default, no pre-trained
479487
weights are used.
480488
progress (bool, optional): If True, displays a progress bar of the
@@ -484,10 +492,10 @@ def mvitv2_t(*, weights: Optional[MViTV2_T_Weights] = None, progress: bool = Tru
484492
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvitv2.py>`_
485493
for more details about this class.
486494
487-
.. autoclass:: torchvision.models.video.MViTV2_T_Weights
495+
.. autoclass:: torchvision.models.video.MViT_V2_T_Weights
488496
:members:
489497
"""
490-
weights = MViTV2_T_Weights.verify(weights)
498+
weights = MViT_V2_T_Weights.verify(weights)
491499

492500
return _mvitv2(
493501
spatial_size=(224, 224),
@@ -502,17 +510,17 @@ def mvitv2_t(*, weights: Optional[MViTV2_T_Weights] = None, progress: bool = Tru
502510
)
503511

504512

505-
def mvitv2_s(*, weights: Optional[MViTV2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTv2:
513+
def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2:
506514
"""
507-
Constructs a small MViTv2 architecture from
515+
Constructs a small MViTV2 architecture from
508516
`MViTv2: Improved Multiscale Vision Transformers for Classification and Detection
509517
<https://arxiv.org/abs/2112.01526>`__ and `Multiscale Vision Transformers
510518
<https://arxiv.org/abs/2104.11227>`__.
511519
512520
Args:
513-
weights (:class:`~torchvision.models.video.MViTV2_S_Weights`, optional): The
521+
weights (:class:`~torchvision.models.video.MViT_V2_S_Weights`, optional): The
514522
pretrained weights to use. See
515-
:class:`~torchvision.models.video.MViTV2_S_Weights` below for
523+
:class:`~torchvision.models.video.MViT_V2_S_Weights` below for
516524
more details, and possible values. By default, no pre-trained
517525
weights are used.
518526
progress (bool, optional): If True, displays a progress bar of the
@@ -522,10 +530,10 @@ def mvitv2_s(*, weights: Optional[MViTV2_S_Weights] = None, progress: bool = Tru
522530
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvitv2.py>`_
523531
for more details about this class.
524532
525-
.. autoclass:: torchvision.models.video.MViTV2_S_Weights
533+
.. autoclass:: torchvision.models.video.MViT_V2_S_Weights
526534
:members:
527535
"""
528-
weights = MViTV2_S_Weights.verify(weights)
536+
weights = MViT_V2_S_Weights.verify(weights)
529537

530538
return _mvitv2(
531539
spatial_size=(224, 224),
@@ -540,17 +548,17 @@ def mvitv2_s(*, weights: Optional[MViTV2_S_Weights] = None, progress: bool = Tru
540548
)
541549

542550

543-
def mvitv2_b(*, weights: Optional[MViTV2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTv2:
551+
def mvit_v2_b(*, weights: Optional[MViT_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2:
544552
"""
545-
Constructs a base MViTv2 architecture from
553+
Constructs a base MViTV2 architecture from
546554
`MViTv2: Improved Multiscale Vision Transformers for Classification and Detection
547555
<https://arxiv.org/abs/2112.01526>`__ and `Multiscale Vision Transformers
548556
<https://arxiv.org/abs/2104.11227>`__.
549557
550558
Args:
551-
weights (:class:`~torchvision.models.video.MViTV2_B_Weights`, optional): The
559+
weights (:class:`~torchvision.models.video.MViT_V2_B_Weights`, optional): The
552560
pretrained weights to use. See
553-
:class:`~torchvision.models.video.MViTV2_B_Weights` below for
561+
:class:`~torchvision.models.video.MViT_V2_B_Weights` below for
554562
more details, and possible values. By default, no pre-trained
555563
weights are used.
556564
progress (bool, optional): If True, displays a progress bar of the
@@ -560,10 +568,10 @@ def mvitv2_b(*, weights: Optional[MViTV2_B_Weights] = None, progress: bool = Tru
560568
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvitv2.py>`_
561569
for more details about this class.
562570
563-
.. autoclass:: torchvision.models.video.MViTV2_B_Weights
571+
.. autoclass:: torchvision.models.video.MViT_V2_B_Weights
564572
:members:
565573
"""
566-
weights = MViTV2_B_Weights.verify(weights)
574+
weights = MViT_V2_B_Weights.verify(weights)
567575

568576
return _mvitv2(
569577
spatial_size=(224, 224),

0 commit comments

Comments
 (0)