Commit c5b2839

rebased + addressed comments
1 parent f15fd92 commit c5b2839

73 files changed (+106 −137 lines changed)

70 binary files changed (contents not shown in the diff view).

test/test_models.py

Lines changed: 0 additions & 8 deletions
@@ -594,14 +594,6 @@ def test_vitc_models(model_fn, dev):
     test_classification_model(model_fn, dev)


-@pytest.mark.parametrize(
-    "model_fn", [models.max_vit_T_224, models.max_vit_S_224, models.max_vit_B_224, models.max_vit_L_224]
-)
-@pytest.mark.parametrize("dev", cpu_and_gpu())
-def test_max_vit(model_fn, dev):
-    test_classification_model(model_fn, dev)
-
-
 @pytest.mark.parametrize("model_fn", list_model_fns(models))
 @pytest.mark.parametrize("dev", cpu_and_gpu())
 def test_classification_model(model_fn, dev):
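Note: the dedicated test_max_vit test removed above is now covered by the generic test_classification_model parametrization over list_model_fns(models), which picks up the @register_model registration added in torchvision/models/maxvit.py below. As an illustration only (not part of this commit), the registered name can be exercised through the public registration API; num_classes is passed explicitly because the builder defines no default for it:

from torchvision.models import get_model, list_models

# After this commit, "maxvit_t" should appear among the registered classification models.
assert "maxvit_t" in list_models()

# get_model forwards keyword arguments to the registered builder.
model = get_model("maxvit_t", weights=None, num_classes=10)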

torchvision/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -8,11 +8,11 @@
 from .mobilenet import *
 from .regnet import *
 from .resnet import *
+from .maxvit import *
 from .shufflenetv2 import *
 from .squeezenet import *
 from .vgg import *
 from .vision_transformer import *
 from .swin_transformer import *
-from .maxvit import *
 from . import detection, optical_flow, quantization, segmentation, video
 from ._api import get_model, get_model_weights, get_weight, list_models

torchvision/models/maxvit.py

Lines changed: 105 additions & 128 deletions
@@ -1,12 +1,15 @@
 import math
-from typing import Any, Callable, List, OrderedDict, Sequence, Tuple
+from typing import Any, Callable, List, Optional, OrderedDict, Sequence, Tuple

 import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn, Tensor
+from torchvision.models._api import register_model, WeightsEnum
+from torchvision.models._utils import _ovewrite_named_param
 from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation
 from torchvision.ops.stochastic_depth import StochasticDepth
+from torchvision.utils import _log_api_usage_once


 def get_relative_position_index(height: int, width: int) -> torch.Tensor:
@@ -20,20 +23,6 @@ def get_relative_position_index(height: int, width: int) -> torch.Tensor:
     return relative_coords.sum(-1)


-class GeluWrapper(nn.Module):
-    """
-    Gelu wrapper to make it compatible with `ConvNormActivation2D` which passed inplace=True
-    to the activation function construction.
-    """
-
-    def __init__(self, **kwargs) -> None:
-        super().__init__()
-        self._op = F.gelu
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self._op(x)
-
-
 class MBConv(nn.Module):
     def __init__(
         self,
@@ -65,20 +54,28 @@ def __init__(
         _layers = OrderedDict()
         _layers["pre_norm"] = normalization_fn(in_channels)
         _layers["conv_a"] = Conv2dNormActivation(
-            in_channels, mid_channels, 1, 1, 0, activation_layer=activation_fn, norm_layer=normalization_fn
+            in_channels,
+            mid_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            activation_layer=activation_fn,
+            norm_layer=normalization_fn,
+            inplace=None,
         )
         _layers["conv_b"] = Conv2dNormActivation(
             mid_channels,
             mid_channels,
-            3,
-            stride,
-            1,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
             activation_layer=activation_fn,
             norm_layer=normalization_fn,
             groups=mid_channels,
+            inplace=None,
         )
         _layers["squeeze_excitation"] = SqueezeExcitation(mid_channels, sqz_channels)
-        _layers["conv_c"] = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, bias=False)
+        _layers["conv_c"] = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, bias=True)

         self.layers = nn.Sequential(_layers)

@@ -116,14 +113,13 @@ def __init__(
         # initialize with truncated normal the bias
         self.positional_bias.data.normal_(mean=0, std=0.02)

-    def _get_relative_positional_bias(self) -> torch.Tensor:
+    def get_relative_positional_bias(self) -> torch.Tensor:
         bias_index = self.relative_position_index.view(-1)  # type: ignore
         relative_bias = self.positional_bias[bias_index].view(self.max_seq_len, self.max_seq_len, -1)  # type: ignore
         relative_bias = relative_bias.permute(2, 0, 1).contiguous()
         return relative_bias.unsqueeze(0)

     def forward(self, x: Tensor) -> Tensor:
-        # X, Y and stand for X-axis group dim, Y-axis group dim
         B, G, P, D = x.shape
         H, DH = self.n_heads, self.head_dim

@@ -135,9 +131,8 @@ def forward(self, x: Tensor) -> Tensor:
         v = v.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)

         k = k * self.scale_factor
-        # X, Y and stand for X-axis group dim, Y-axis group dim
         dot_prod = torch.einsum("B G H I D, B G H J D -> B G H I J", q, k)
-        pos_bias = self._get_relative_positional_bias()
+        pos_bias = self.get_relative_positional_bias()

         dot_prod = F.softmax(dot_prod + pos_bias, dim=-1)

@@ -204,34 +199,6 @@ def forward(self, x: Tensor) -> Tensor:
         return x


-class MLP(nn.Module):
-    def __init__(
-        self,
-        in_dim: int,
-        hidden_dim: int,
-        activation_fn: Callable[..., nn.Module],
-        normalization_fn: Callable[..., nn.Module],
-        dropout: float,
-    ) -> None:
-        super().__init__()
-        self.in_dim = in_dim
-        self.hidden_dim = hidden_dim
-        self.activation_fn = activation_fn
-        self.normalization_fn = normalization_fn
-        self.dropout = dropout
-
-        self.layers = nn.Sequential(
-            self.normalization_fn(in_dim),
-            nn.Linear(in_dim, hidden_dim),
-            self.activation_fn(),
-            nn.Linear(hidden_dim, in_dim),
-            nn.Dropout(dropout),
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        return x + self.layers(x)
-
-
 class PartitionAttentionLayer(nn.Module):
     def __init__(
         self,
@@ -282,16 +249,23 @@ def __init__(
             nn.Dropout(attn_dropout),
         )

-        self.mlp_layer = MLP(in_channels, in_channels * mlp_ratio, activation_fn, normalization_fn, mlp_dropout)
+        # pre-normalization similar to transformer layers
+        self.mlp_layer = nn.Sequential(
+            nn.LayerNorm(in_channels),
+            nn.Linear(in_channels, in_channels * mlp_ratio),
+            activation_fn(),
+            nn.Linear(in_channels * mlp_ratio, in_channels),
+            nn.Dropout(mlp_dropout),
+        )

         # layer scale factors
         self.attn_layer_scale = nn.parameter.Parameter(torch.ones(in_channels) * 1e-6)
         self.mlp_layer_scale = nn.parameter.Parameter(torch.ones(in_channels) * 1e-6)

     def forward(self, x: Tensor) -> Tensor:
         x = self.partition_op(x)
-        x = self.attn_layer(x) * self.attn_layer_scale
-        x = self.mlp_layer(x) * self.mlp_layer_scale
+        x = x + self.attn_layer(x) * self.attn_layer_scale
+        x = x + self.mlp_layer(x) * self.mlp_layer_scale
         x = self.departition_op(x)
         return x

@@ -386,9 +360,8 @@ def __init__(
         p_stochastic: List[float],
     ) -> None:
         super().__init__()
-        assert (
-            len(p_stochastic) == n_layers
-        ), f"p_stochastic must have length n_layers={n_layers}, got p_stochastic={p_stochastic}."
+        if not len(p_stochastic) == n_layers:
+            raise ValueError(f"p_stochastic must have length n_layers={n_layers}, got p_stochastic={p_stochastic}.")

         self.layers = nn.ModuleList()
         # account for the first stride of the first layer
@@ -424,11 +397,12 @@ def forward(self, x: Tensor) -> Tensor:
 class MaxVit(nn.Module):
     def __init__(
         self,
+        # input size parameters
+        input_size: Tuple[int, int],
         # stem and task parameters
         input_channels: int,
         stem_channels: int,
-        input_size: Tuple[int, int],
-        out_classes: int,
+        num_classes: int,
         # block parameters
         block_channels: List[int],
         block_layers: List[int],
@@ -450,6 +424,7 @@ def __init__(
         partition_size: int,
     ) -> None:
         super().__init__()
+        _log_api_usage_once(self)

         # stem
         self.stem = nn.Sequential(
@@ -500,7 +475,7 @@ def __init__(
         self.classifier = nn.Sequential(
             nn.AdaptiveAvgPool2d(1),
             nn.Flatten(),
-            nn.Linear(block_channels[-1], out_classes, bias=False),
+            nn.Linear(block_channels[-1], num_classes, bias=False),
         )

     def forward(self, x: Tensor) -> Tensor:
@@ -511,85 +486,87 @@ def forward(self, x: Tensor) -> Tensor:
         return x


-def max_vit_T_224(num_classes: int) -> MaxVit:
-    return MaxVit(
-        input_channels=3,
-        stem_channels=64,
-        input_size=(224, 224),
-        out_classes=num_classes,
-        block_channels=[64, 128, 256, 512],
-        block_layers=[2, 2, 5, 2],
-        stochastic_depth_prob=0.2,
-        squeeze_ratio=0.25,
-        expansion_ratio=4.0,
-        normalization_fn=nn.BatchNorm2d,
-        activation_fn=GeluWrapper,
-        head_dim=32,
-        mlp_ratio=2,
-        mlp_dropout=0.0,
-        attn_dropout=0.0,
-        partition_size=7,
+def _maxvit(
+    # stem and task parameters
+    stem_channels: int,
+    num_classes: int,
+    # block parameters
+    block_channels: List[int],
+    block_layers: List[int],
+    stochastic_depth_prob: float,
+    # conv parameters
+    squeeze_ratio: float,
+    expansion_ratio: float,
+    # conv + transformer parameters
+    # normalization_fn is applied only to the conv layers
+    # activation_fn is applied both to conv and transformer layers
+    normalization_fn: Callable[..., nn.Module],
+    activation_fn: Callable[..., nn.Module],
+    # transformer parameters
+    head_dim: int,
+    mlp_ratio: int,
+    mlp_dropout: float,
+    attn_dropout: float,
+    # partitioning parameters
+    partition_size: int,
+    # Weights API
+    weights: Optional[WeightsEnum],
+    progress: bool,
+    # kwargs,
+    **kwargs,
+) -> MaxVit:
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+        assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
+        _ovewrite_named_param(kwargs, "input_size", weights.meta["min_size"][0])
+        _ovewrite_named_param(kwargs, "input_channels", weights.meta["input_channels"])
+
+    input_size = kwargs.pop("input_size", (224, 224))
+    input_channels = kwargs.pop("input_channels", 3)
+
+    model = MaxVit(
+        input_channels=input_channels,
+        stem_channels=stem_channels,
+        num_classes=num_classes,
+        block_channels=block_channels,
+        block_layers=block_layers,
+        stochastic_depth_prob=stochastic_depth_prob,
+        squeeze_ratio=squeeze_ratio,
+        expansion_ratio=expansion_ratio,
+        normalization_fn=normalization_fn,
+        activation_fn=activation_fn,
+        head_dim=head_dim,
+        mlp_ratio=mlp_ratio,
+        mlp_dropout=mlp_dropout,
+        attn_dropout=attn_dropout,
+        partition_size=partition_size,
+        input_size=input_size,
+        **kwargs,
     )

+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))

-def max_vit_S_224(num_classes: int) -> MaxVit:
-    return MaxVit(
-        input_channels=3,
-        stem_channels=64,
-        input_size=(224, 224),
-        out_classes=num_classes,
-        block_channels=[96, 192, 384, 768],
-        block_layers=[2, 2, 5, 2],
-        stochastic_depth_prob=0.3,
-        squeeze_ratio=0.25,
-        expansion_ratio=4.0,
-        normalization_fn=nn.BatchNorm2d,
-        activation_fn=GeluWrapper,
-        head_dim=32,
-        mlp_ratio=2,
-        mlp_dropout=0.0,
-        attn_dropout=0.0,
-        partition_size=7,
-    )
+    return model


-def max_vit_B_224(num_classes: int) -> MaxVit:
-    return MaxVit(
-        input_channels=3,
+@register_model(name="maxvit_t")
+def maxvit_t(*, weights: Optional[WeightsEnum] = None, progress: bool = True, **kwargs: Any) -> MaxVit:
+    return _maxvit(
         stem_channels=64,
-        input_size=(224, 224),
-        out_classes=num_classes,
-        block_channels=[96, 192, 384, 768],
-        block_layers=[2, 6, 14, 2],
-        stochastic_depth_prob=0.4,
-        squeeze_ratio=0.25,
-        expansion_ratio=4.0,
-        normalization_fn=nn.BatchNorm2d,
-        activation_fn=GeluWrapper,
-        head_dim=32,
-        mlp_ratio=2,
-        mlp_dropout=0.0,
-        attn_dropout=0.0,
-        partition_size=7,
-    )
-
-
-def max_vit_L_224(num_classes: int) -> MaxVit:
-    return MaxVit(
-        input_channels=3,
-        stem_channels=128,
-        input_size=(224, 224),
-        out_classes=num_classes,
-        block_channels=[128, 256, 512, 1024],
-        block_layers=[2, 6, 14, 2],
-        stochastic_depth_prob=0.6,
+        block_channels=[64, 128, 256, 512],
+        block_layers=[2, 2, 5, 2],
+        stochastic_depth_prob=0.2,
         squeeze_ratio=0.25,
         expansion_ratio=4.0,
         normalization_fn=nn.BatchNorm2d,
-        activation_fn=GeluWrapper,
+        activation_fn=nn.GELU,
         head_dim=32,
         mlp_ratio=2,
         mlp_dropout=0.0,
         attn_dropout=0.0,
         partition_size=7,
+        weights=weights,
+        progress=progress,
+        **kwargs,
     )
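For reference, a minimal usage sketch of the refactored entry point (assumptions: a torchvision build containing this commit; no pretrained weights enum appears in this diff, so weights=None and an explicit num_classes are required):

import torch
from torchvision.models import maxvit_t

# Build a randomly initialized MaxViT-T classifier; the builder defaults to
# input_size=(224, 224) and input_channels=3 when no weights are given.
model = maxvit_t(weights=None, num_classes=1000).eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])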

0 commit comments
