12
12
from .._utils import _ovewrite_named_param
13
13
14
14
15
- __all__ = ["MViTv2" , "MViTV2_T_Weights" , "MViTV2_S_Weights" , "MViTV2_B_Weights" , "mvitv2_t" , "mvitv2_s" , "mvitv2_b" ]
15
+ __all__ = [
16
+ "MViTV2" ,
17
+ "MViT_V2_T_Weights" ,
18
+ "MViT_V2_S_Weights" ,
19
+ "MViT_V2_B_Weights" ,
20
+ "mvit_v2_t" ,
21
+ "mvit_v2_s" ,
22
+ "mvit_v2_b" ,
23
+ ]
16
24
17
25
18
26
# TODO: add weights
@@ -264,7 +272,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
264
272
return torch .cat ((class_token , x ), dim = 1 ) + pos_embedding
265
273
266
274
267
- class MViTv2 (nn .Module ):
275
+ class MViTV2 (nn .Module ):
268
276
def __init__ (
269
277
self ,
270
278
spatial_size : Tuple [int , int ],
@@ -283,7 +291,7 @@ def __init__(
283
291
norm_layer : Optional [Callable [..., nn .Module ]] = None ,
284
292
) -> None :
285
293
"""
286
- MViTv2 main class.
294
+ MViT V2 main class.
287
295
288
296
Args:
289
297
spatial_size (tuple of ints): The spatial size of the input as ``(H, W)``.
@@ -424,7 +432,7 @@ def _mvitv2(
424
432
weights : Optional [WeightsEnum ],
425
433
progress : bool ,
426
434
** kwargs : Any ,
427
- ) -> MViTv2 :
435
+ ) -> MViTV2 :
428
436
if weights is not None :
429
437
_ovewrite_named_param (kwargs , "num_classes" , len (weights .meta ["categories" ]))
430
438
assert weights .meta ["min_size" ][0 ] == weights .meta ["min_size" ][1 ]
@@ -433,7 +441,7 @@ def _mvitv2(
433
441
spatial_size = kwargs .pop ("spatial_size" , (224 , 224 ))
434
442
temporal_size = kwargs .pop ("temporal_size" , 16 )
435
443
436
- model = MViTv2 (
444
+ model = MViTV2 (
437
445
spatial_size = spatial_size ,
438
446
temporal_size = temporal_size ,
439
447
embed_channels = embed_channels ,
@@ -452,29 +460,29 @@ def _mvitv2(
452
460
return model
453
461
454
462
455
- class MViTV2_T_Weights (WeightsEnum ):
463
+ class MViT_V2_T_Weights (WeightsEnum ):
456
464
pass
457
465
458
466
459
- class MViTV2_S_Weights (WeightsEnum ):
467
+ class MViT_V2_S_Weights (WeightsEnum ):
460
468
pass
461
469
462
470
463
- class MViTV2_B_Weights (WeightsEnum ):
471
+ class MViT_V2_B_Weights (WeightsEnum ):
464
472
pass
465
473
466
474
467
- def mvitv2_t (* , weights : Optional [MViTV2_T_Weights ] = None , progress : bool = True , ** kwargs : Any ) -> MViTv2 :
475
+ def mvit_v2_t (* , weights : Optional [MViT_V2_T_Weights ] = None , progress : bool = True , ** kwargs : Any ) -> MViTV2 :
468
476
"""
469
- Constructs a tiny MViTv2 architecture from
477
+ Constructs a tiny MViTV2 architecture from
470
478
`MViTv2: Improved Multiscale Vision Transformers for Classification and Detection
471
479
<https://arxiv.org/abs/2112.01526>`__ and `Multiscale Vision Transformers
472
480
<https://arxiv.org/abs/2104.11227>`__.
473
481
474
482
Args:
475
- weights (:class:`~torchvision.models.video.MViTV2_T_Weights `, optional): The
483
+ weights (:class:`~torchvision.models.video.MViT_V2_T_Weights`, optional): The
476
484
pretrained weights to use. See
477
- :class:`~torchvision.models.video.MViTV2_T_Weights ` below for
485
+ :class:`~torchvision.models.video.MViT_V2_T_Weights` below for
478
486
more details, and possible values. By default, no pre-trained
479
487
weights are used.
480
488
progress (bool, optional): If True, displays a progress bar of the
@@ -484,10 +492,10 @@ def mvitv2_t(*, weights: Optional[MViTV2_T_Weights] = None, progress: bool = Tru
484
492
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvitv2.py>`_
485
493
for more details about this class.
486
494
487
- .. autoclass:: torchvision.models.video.MViTV2_T_Weights
495
+ .. autoclass:: torchvision.models.video.MViT_V2_T_Weights
488
496
:members:
489
497
"""
490
- weights = MViTV2_T_Weights .verify (weights )
498
+ weights = MViT_V2_T_Weights .verify (weights )
491
499
492
500
return _mvitv2 (
493
501
spatial_size = (224 , 224 ),
@@ -502,17 +510,17 @@ def mvitv2_t(*, weights: Optional[MViTV2_T_Weights] = None, progress: bool = Tru
502
510
)
503
511
504
512
505
- def mvitv2_s (* , weights : Optional [MViTV2_S_Weights ] = None , progress : bool = True , ** kwargs : Any ) -> MViTv2 :
513
+ def mvit_v2_s (* , weights : Optional [MViT_V2_S_Weights ] = None , progress : bool = True , ** kwargs : Any ) -> MViTV2 :
506
514
"""
507
- Constructs a small MViTv2 architecture from
515
+ Constructs a small MViTV2 architecture from
508
516
`MViTv2: Improved Multiscale Vision Transformers for Classification and Detection
509
517
<https://arxiv.org/abs/2112.01526>`__ and `Multiscale Vision Transformers
510
518
<https://arxiv.org/abs/2104.11227>`__.
511
519
512
520
Args:
513
- weights (:class:`~torchvision.models.video.MViTV2_S_Weights `, optional): The
521
+ weights (:class:`~torchvision.models.video.MViT_V2_S_Weights`, optional): The
514
522
pretrained weights to use. See
515
- :class:`~torchvision.models.video.MViTV2_S_Weights ` below for
523
+ :class:`~torchvision.models.video.MViT_V2_S_Weights` below for
516
524
more details, and possible values. By default, no pre-trained
517
525
weights are used.
518
526
progress (bool, optional): If True, displays a progress bar of the
@@ -522,10 +530,10 @@ def mvitv2_s(*, weights: Optional[MViTV2_S_Weights] = None, progress: bool = Tru
522
530
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvitv2.py>`_
523
531
for more details about this class.
524
532
525
- .. autoclass:: torchvision.models.video.MViTV2_S_Weights
533
+ .. autoclass:: torchvision.models.video.MViT_V2_S_Weights
526
534
:members:
527
535
"""
528
- weights = MViTV2_S_Weights .verify (weights )
536
+ weights = MViT_V2_S_Weights .verify (weights )
529
537
530
538
return _mvitv2 (
531
539
spatial_size = (224 , 224 ),
@@ -540,17 +548,17 @@ def mvitv2_s(*, weights: Optional[MViTV2_S_Weights] = None, progress: bool = Tru
540
548
)
541
549
542
550
543
- def mvitv2_b (* , weights : Optional [MViTV2_B_Weights ] = None , progress : bool = True , ** kwargs : Any ) -> MViTv2 :
551
+ def mvit_v2_b (* , weights : Optional [MViT_V2_B_Weights ] = None , progress : bool = True , ** kwargs : Any ) -> MViTV2 :
544
552
"""
545
- Constructs a base MViTv2 architecture from
553
+ Constructs a base MViTV2 architecture from
546
554
`MViTv2: Improved Multiscale Vision Transformers for Classification and Detection
547
555
<https://arxiv.org/abs/2112.01526>`__ and `Multiscale Vision Transformers
548
556
<https://arxiv.org/abs/2104.11227>`__.
549
557
550
558
Args:
551
- weights (:class:`~torchvision.models.video.MViTV2_B_Weights `, optional): The
559
+ weights (:class:`~torchvision.models.video.MViT_V2_B_Weights`, optional): The
552
560
pretrained weights to use. See
553
- :class:`~torchvision.models.video.MViTV2_B_Weights ` below for
561
+ :class:`~torchvision.models.video.MViT_V2_B_Weights` below for
554
562
more details, and possible values. By default, no pre-trained
555
563
weights are used.
556
564
progress (bool, optional): If True, displays a progress bar of the
@@ -560,10 +568,10 @@ def mvitv2_b(*, weights: Optional[MViTV2_B_Weights] = None, progress: bool = Tru
560
568
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvitv2.py>`_
561
569
for more details about this class.
562
570
563
- .. autoclass:: torchvision.models.video.MViTV2_B_Weights
571
+ .. autoclass:: torchvision.models.video.MViT_V2_B_Weights
564
572
:members:
565
573
"""
566
- weights = MViTV2_B_Weights .verify (weights )
574
+ weights = MViT_V2_B_Weights .verify (weights )
567
575
568
576
return _mvitv2 (
569
577
spatial_size = (224 , 224 ),
0 commit comments