Skip to content

Commit 22c9850

Browse files
committed
Fix docs, mypy and linter
1 parent 78851e6 commit 22c9850

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

torchvision/models/video/mvit.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
# Reference: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932#
2828

29+
2930
def get_rel_pos(rel_pos: torch.Tensor, d: int) -> torch.Tensor:
3031
if rel_pos.shape[0] == d:
3132
return rel_pos
@@ -290,9 +291,9 @@ def __init__(
290291
norm_layer(self.head_dim),
291292
)
292293

293-
self.rel_pos_h: Optional[nn.Module] = None
294-
self.rel_pos_w: Optional[nn.Module] = None
295-
self.rel_pos_t: Optional[nn.Module] = None
294+
self.rel_pos_h: Optional[nn.Parameter] = None
295+
self.rel_pos_w: Optional[nn.Parameter] = None
296+
self.rel_pos_t: Optional[nn.Parameter] = None
296297
if rel_pos:
297298
assert input_size[1] == input_size[2] # TODO: remove this limitation
298299
size = input_size[1]
@@ -471,6 +472,8 @@ def __init__(
471472
temporal_size (int): The temporal size ``T`` of the input.
472473
block_setting (sequence of MSBlockConfig): The Network structure.
473474
residual_pool (bool): If True, use MViTv2 pooling residual connection.
475+
rel_pos (bool): TODO
476+
dim_mul_in_att (bool): TODO
474477
dropout (float): Dropout rate. Default: 0.0.
475478
attention_dropout (float): Attention dropout rate. Default: 0.0.
476479
stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
@@ -508,7 +511,7 @@ def __init__(
508511
# Spatio-Temporal Class Positional Encoding
509512
self.pos_encoding = PositionalEncoding(
510513
embed_size=block_setting[0].input_channels,
511-
spatial_size=tuple(input_size[1:]),
514+
spatial_size=(input_size[1], input_size[2]),
512515
temporal_size=input_size[0],
513516
rel_pos=rel_pos,
514517
)

0 commit comments

Comments
 (0)