import math
- from typing import Any, Callable, List, OrderedDict, Sequence, Tuple
+ from typing import Any, Callable, List, Optional, OrderedDict, Sequence, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor
+ from torchvision.models._api import register_model, WeightsEnum
+ from torchvision.models._utils import _ovewrite_named_param
from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation
from torchvision.ops.stochastic_depth import StochasticDepth
+ from torchvision.utils import _log_api_usage_once


def get_relative_position_index(height: int, width: int) -> torch.Tensor:
@@ -20,20 +23,6 @@ def get_relative_position_index(height: int, width: int) -> torch.Tensor:
    return relative_coords.sum(-1)


- class GeluWrapper(nn.Module):
-     """
-     Gelu wrapper to make it compatible with `ConvNormActivation2D` which passed inplace=True
-     to the activation function construction.
-     """
-
-     def __init__(self, **kwargs) -> None:
-         super().__init__()
-         self._op = F.gelu
-
-     def forward(self, x: Tensor) -> Tensor:
-         return self._op(x)
-
-
class MBConv(nn.Module):
    def __init__(
        self,
@@ -65,20 +54,28 @@ def __init__(
        _layers = OrderedDict()
        _layers["pre_norm"] = normalization_fn(in_channels)
        _layers["conv_a"] = Conv2dNormActivation(
-             in_channels, mid_channels, 1, 1, 0, activation_layer=activation_fn, norm_layer=normalization_fn
+             in_channels,
+             mid_channels,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+             activation_layer=activation_fn,
+             norm_layer=normalization_fn,
+             inplace=None,
        )
        _layers["conv_b"] = Conv2dNormActivation(
            mid_channels,
            mid_channels,
-             3,
-             stride,
-             1,
+             kernel_size=3,
+             stride=stride,
+             padding=1,
            activation_layer=activation_fn,
            norm_layer=normalization_fn,
            groups=mid_channels,
+             inplace=None,
        )
        _layers["squeeze_excitation"] = SqueezeExcitation(mid_channels, sqz_channels)
-         _layers["conv_c"] = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, bias=False)
+         _layers["conv_c"] = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, bias=True)

        self.layers = nn.Sequential(_layers)

@@ -116,14 +113,13 @@ def __init__(
        # initialize with truncated normal the bias
        self.positional_bias.data.normal_(mean=0, std=0.02)

-     def _get_relative_positional_bias(self) -> torch.Tensor:
+     def get_relative_positional_bias(self) -> torch.Tensor:
        bias_index = self.relative_position_index.view(-1)  # type: ignore
        relative_bias = self.positional_bias[bias_index].view(self.max_seq_len, self.max_seq_len, -1)  # type: ignore
        relative_bias = relative_bias.permute(2, 0, 1).contiguous()
        return relative_bias.unsqueeze(0)

    def forward(self, x: Tensor) -> Tensor:
-         # X, Y and stand for X-axis group dim, Y-axis group dim
        B, G, P, D = x.shape
        H, DH = self.n_heads, self.head_dim

@@ -135,9 +131,8 @@ def forward(self, x: Tensor) -> Tensor:
        v = v.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)

        k = k * self.scale_factor
-         # X, Y and stand for X-axis group dim, Y-axis group dim
        dot_prod = torch.einsum("B G H I D, B G H J D -> B G H I J", q, k)
-         pos_bias = self._get_relative_positional_bias()
+         pos_bias = self.get_relative_positional_bias()

        dot_prod = F.softmax(dot_prod + pos_bias, dim=-1)

@@ -204,34 +199,6 @@ def forward(self, x: Tensor) -> Tensor:
        return x


- class MLP(nn.Module):
-     def __init__(
-         self,
-         in_dim: int,
-         hidden_dim: int,
-         activation_fn: Callable[..., nn.Module],
-         normalization_fn: Callable[..., nn.Module],
-         dropout: float,
-     ) -> None:
-         super().__init__()
-         self.in_dim = in_dim
-         self.hidden_dim = hidden_dim
-         self.activation_fn = activation_fn
-         self.normalization_fn = normalization_fn
-         self.dropout = dropout
-
-         self.layers = nn.Sequential(
-             self.normalization_fn(in_dim),
-             nn.Linear(in_dim, hidden_dim),
-             self.activation_fn(),
-             nn.Linear(hidden_dim, in_dim),
-             nn.Dropout(dropout),
-         )
-
-     def forward(self, x: Tensor) -> Tensor:
-         return x + self.layers(x)
-
-
class PartitionAttentionLayer(nn.Module):
    def __init__(
        self,
@@ -282,16 +249,23 @@ def __init__(
            nn.Dropout(attn_dropout),
        )

-         self.mlp_layer = MLP(in_channels, in_channels * mlp_ratio, activation_fn, normalization_fn, mlp_dropout)
+         # pre-normalization similar to transformer layers
+         self.mlp_layer = nn.Sequential(
+             nn.LayerNorm(in_channels),
+             nn.Linear(in_channels, in_channels * mlp_ratio),
+             activation_fn(),
+             nn.Linear(in_channels * mlp_ratio, in_channels),
+             nn.Dropout(mlp_dropout),
+         )

        # layer scale factors
        self.attn_layer_scale = nn.parameter.Parameter(torch.ones(in_channels) * 1e-6)
        self.mlp_layer_scale = nn.parameter.Parameter(torch.ones(in_channels) * 1e-6)

    def forward(self, x: Tensor) -> Tensor:
        x = self.partition_op(x)
-         x = self.attn_layer(x) * self.attn_layer_scale
-         x = self.mlp_layer(x) * self.mlp_layer_scale
+         x = x + self.attn_layer(x) * self.attn_layer_scale
+         x = x + self.mlp_layer(x) * self.mlp_layer_scale
        x = self.departition_op(x)
        return x

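Side note on the hunk above: the forward pass now adds each branch back to its input, with the branch output scaled by the learnable per-channel factors the existing "# layer scale factors" comment refers to. A minimal standalone sketch of that pattern, using a hypothetical LayerScaleResidual wrapper with a generic branch module and dim argument (none of these names appear in the file):

import torch
from torch import nn

class LayerScaleResidual(nn.Module):
    # Illustrative only: y = x + branch(x) * scale, with scale a learnable
    # per-channel factor initialised to a small value (1e-6 in the code above).
    def __init__(self, branch: nn.Module, dim: int, init_value: float = 1e-6) -> None:
        super().__init__()
        self.branch = branch
        self.scale = nn.Parameter(torch.ones(dim) * init_value)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale the branch output, then add it back to the input (residual connection).
        return x + self.branch(x) * self.scale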
@@ -386,9 +360,8 @@ def __init__(
        p_stochastic: List[float],
    ) -> None:
        super().__init__()
-         assert (
-             len(p_stochastic) == n_layers
-         ), f"p_stochastic must have length n_layers={n_layers}, got p_stochastic={p_stochastic}."
+         if not len(p_stochastic) == n_layers:
+             raise ValueError(f"p_stochastic must have length n_layers={n_layers}, got p_stochastic={p_stochastic}.")

        self.layers = nn.ModuleList()
        # account for the first stride of the first layer
@@ -424,11 +397,12 @@ def forward(self, x: Tensor) -> Tensor:
class MaxVit(nn.Module):
    def __init__(
        self,
+         # input size parameters
+         input_size: Tuple[int, int],
        # stem and task parameters
        input_channels: int,
        stem_channels: int,
-         input_size: Tuple[int, int],
-         out_classes: int,
+         num_classes: int,
        # block parameters
        block_channels: List[int],
        block_layers: List[int],
@@ -450,6 +424,7 @@ def __init__(
        partition_size: int,
    ) -> None:
        super().__init__()
+         _log_api_usage_once(self)

        # stem
        self.stem = nn.Sequential(
@@ -500,7 +475,7 @@ def __init__(
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
-             nn.Linear(block_channels[-1], out_classes, bias=False),
+             nn.Linear(block_channels[-1], num_classes, bias=False),
        )

    def forward(self, x: Tensor) -> Tensor:
@@ -511,85 +486,87 @@ def forward(self, x: Tensor) -> Tensor:
        return x


- def max_vit_T_224(num_classes: int) -> MaxVit:
-     return MaxVit(
-         input_channels=3,
-         stem_channels=64,
-         input_size=(224, 224),
-         out_classes=num_classes,
-         block_channels=[64, 128, 256, 512],
-         block_layers=[2, 2, 5, 2],
-         stochastic_depth_prob=0.2,
-         squeeze_ratio=0.25,
-         expansion_ratio=4.0,
-         normalization_fn=nn.BatchNorm2d,
-         activation_fn=GeluWrapper,
-         head_dim=32,
-         mlp_ratio=2,
-         mlp_dropout=0.0,
-         attn_dropout=0.0,
-         partition_size=7,
+ def _maxvit(
+     # stem and task parameters
+     stem_channels: int,
+     num_classes: int,
+     # block parameters
+     block_channels: List[int],
+     block_layers: List[int],
+     stochastic_depth_prob: float,
+     # conv parameters
+     squeeze_ratio: float,
+     expansion_ratio: float,
+     # conv + transformer parameters
+     # normalization_fn is applied only to the conv layers
+     # activation_fn is applied both to conv and transformer layers
+     normalization_fn: Callable[..., nn.Module],
+     activation_fn: Callable[..., nn.Module],
+     # transformer parameters
+     head_dim: int,
+     mlp_ratio: int,
+     mlp_dropout: float,
+     attn_dropout: float,
+     # partitioning parameters
+     partition_size: int,
+     # Weights API
+     weights: Optional[WeightsEnum],
+     progress: bool,
+     # kwargs,
+     **kwargs,
+ ) -> MaxVit:
+     if weights is not None:
+         _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+         assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
+         _ovewrite_named_param(kwargs, "input_size", weights.meta["min_size"][0])
+         _ovewrite_named_param(kwargs, "input_channels", weights.meta["input_channels"])
+
+     input_size = kwargs.pop("input_size", (224, 224))
+     input_channels = kwargs.pop("input_channels", 3)
+
+     model = MaxVit(
+         input_channels=input_channels,
+         stem_channels=stem_channels,
+         num_classes=num_classes,
+         block_channels=block_channels,
+         block_layers=block_layers,
+         stochastic_depth_prob=stochastic_depth_prob,
+         squeeze_ratio=squeeze_ratio,
+         expansion_ratio=expansion_ratio,
+         normalization_fn=normalization_fn,
+         activation_fn=activation_fn,
+         head_dim=head_dim,
+         mlp_ratio=mlp_ratio,
+         mlp_dropout=mlp_dropout,
+         attn_dropout=attn_dropout,
+         partition_size=partition_size,
+         input_size=input_size,
+         **kwargs,
    )

+     if weights is not None:
+         model.load_state_dict(weights.get_state_dict(progress=progress))

- def max_vit_S_224(num_classes: int) -> MaxVit:
-     return MaxVit(
-         input_channels=3,
-         stem_channels=64,
-         input_size=(224, 224),
-         out_classes=num_classes,
-         block_channels=[96, 192, 384, 768],
-         block_layers=[2, 2, 5, 2],
-         stochastic_depth_prob=0.3,
-         squeeze_ratio=0.25,
-         expansion_ratio=4.0,
-         normalization_fn=nn.BatchNorm2d,
-         activation_fn=GeluWrapper,
-         head_dim=32,
-         mlp_ratio=2,
-         mlp_dropout=0.0,
-         attn_dropout=0.0,
-         partition_size=7,
-     )
+     return model


- def max_vit_B_224(num_classes: int) -> MaxVit:
-     return MaxVit(
-         input_channels=3,
+ @register_model(name="maxvit_t")
+ def maxvit_t(*, weights: Optional[WeightsEnum] = None, progress: bool = True, **kwargs: Any) -> MaxVit:
+     return _maxvit(
        stem_channels=64,
-         input_size=(224, 224),
-         out_classes=num_classes,
-         block_channels=[96, 192, 384, 768],
-         block_layers=[2, 6, 14, 2],
-         stochastic_depth_prob=0.4,
-         squeeze_ratio=0.25,
-         expansion_ratio=4.0,
-         normalization_fn=nn.BatchNorm2d,
-         activation_fn=GeluWrapper,
-         head_dim=32,
-         mlp_ratio=2,
-         mlp_dropout=0.0,
-         attn_dropout=0.0,
-         partition_size=7,
-     )
-
-
- def max_vit_L_224(num_classes: int) -> MaxVit:
-     return MaxVit(
-         input_channels=3,
-         stem_channels=128,
-         input_size=(224, 224),
-         out_classes=num_classes,
-         block_channels=[128, 256, 512, 1024],
-         block_layers=[2, 6, 14, 2],
-         stochastic_depth_prob=0.6,
+         block_channels=[64, 128, 256, 512],
+         block_layers=[2, 2, 5, 2],
+         stochastic_depth_prob=0.2,
        squeeze_ratio=0.25,
        expansion_ratio=4.0,
        normalization_fn=nn.BatchNorm2d,
-         activation_fn=GeluWrapper,
+         activation_fn=nn.GELU,
        head_dim=32,
        mlp_ratio=2,
        mlp_dropout=0.0,
        attn_dropout=0.0,
        partition_size=7,
+         weights=weights,
+         progress=progress,
+         **kwargs,
    )
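A minimal usage sketch for the new builder, assuming the registered maxvit_t ends up exported from torchvision.models (the export is not shown in this diff) and that no pretrained weights are requested; in this revision num_classes still has to be supplied through kwargs because maxvit_t does not forward a default for it:

import torch
from torchvision.models import maxvit_t  # assumed export; registered above via @register_model(name="maxvit_t")

# Build the tiny variant without weights; num_classes reaches _maxvit through **kwargs.
model = maxvit_t(weights=None, num_classes=1000)
model.eval()

# With no weights, _maxvit falls back to input_size=(224, 224) and input_channels=3.
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))

print(logits.shape)  # expected: torch.Size([1, 1000])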