@@ -290,9 +290,9 @@ def __init__(
290
290
norm_layer (self .head_dim ),
291
291
)
292
292
293
- self .rel_pos_h : Optional [nn .Module ] = None
294
- self .rel_pos_w : Optional [nn .Module ] = None
295
- self .rel_pos_t : Optional [nn .Module ] = None
293
+ self .rel_pos_h : Optional [nn .Parameter ] = None
294
+ self .rel_pos_w : Optional [nn .Parameter ] = None
295
+ self .rel_pos_t : Optional [nn .Parameter ] = None
296
296
if rel_pos :
297
297
assert input_size [1 ] == input_size [2 ] # TODO: remove this limitation
298
298
size = input_size [1 ]
@@ -471,6 +471,8 @@ def __init__(
471
471
temporal_size (int): The temporal size ``T`` of the input.
472
472
block_setting (sequence of MSBlockConfig): The Network structure.
473
473
residual_pool (bool): If True, use MViTv2 pooling residual connection.
474
+ rel_pos (bool): TODO
475
+ dim_mul_in_att (bool): TODO
474
476
dropout (float): Dropout rate. Default: 0.0.
475
477
attention_dropout (float): Attention dropout rate. Default: 0.0.
476
478
stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
@@ -508,7 +510,7 @@ def __init__(
508
510
# Spatio-Temporal Class Positional Encoding
509
511
self .pos_encoding = PositionalEncoding (
510
512
embed_size = block_setting [0 ].input_channels ,
511
- spatial_size = tuple (input_size [1 : ]),
513
+ spatial_size = (input_size [1 ], input_size [ 2 ]),
512
514
temporal_size = input_size [0 ],
513
515
rel_pos = rel_pos ,
514
516
)
0 commit comments