Skip to content

Commit 98d30fd

Browse files
committed
Fix docs and mypy
1 parent 78851e6 commit 98d30fd

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

torchvision/models/video/mvit.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,9 @@ def __init__(
290290
norm_layer(self.head_dim),
291291
)
292292

293-
self.rel_pos_h: Optional[nn.Module] = None
294-
self.rel_pos_w: Optional[nn.Module] = None
295-
self.rel_pos_t: Optional[nn.Module] = None
293+
self.rel_pos_h: Optional[nn.Parameter] = None
294+
self.rel_pos_w: Optional[nn.Parameter] = None
295+
self.rel_pos_t: Optional[nn.Parameter] = None
296296
if rel_pos:
297297
assert input_size[1] == input_size[2] # TODO: remove this limitation
298298
size = input_size[1]
@@ -471,6 +471,8 @@ def __init__(
471471
temporal_size (int): The temporal size ``T`` of the input.
472472
block_setting (sequence of MSBlockConfig): The Network structure.
473473
residual_pool (bool): If True, use MViTv2 pooling residual connection.
474+
rel_pos (bool): TODO
475+
dim_mul_in_att (bool): TODO
474476
dropout (float): Dropout rate. Default: 0.0.
475477
attention_dropout (float): Attention dropout rate. Default: 0.0.
476478
stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
@@ -508,7 +510,7 @@ def __init__(
508510
# Spatio-Temporal Class Positional Encoding
509511
self.pos_encoding = PositionalEncoding(
510512
embed_size=block_setting[0].input_channels,
511-
spatial_size=tuple(input_size[1:]),
513+
spatial_size=(input_size[1], input_size[2]),
512514
temporal_size=input_size[0],
513515
rel_pos=rel_pos,
514516
)

0 commit comments

Comments
 (0)