|
26 | 26 |
|
27 | 27 | # Reference: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932#
|
28 | 28 |
|
| 29 | + |
29 | 30 | def get_rel_pos(rel_pos: torch.Tensor, d: int) -> torch.Tensor:
|
30 | 31 | if rel_pos.shape[0] == d:
|
31 | 32 | return rel_pos
|
@@ -290,9 +291,9 @@ def __init__(
|
290 | 291 | norm_layer(self.head_dim),
|
291 | 292 | )
|
292 | 293 |
|
293 |
| - self.rel_pos_h: Optional[nn.Module] = None |
294 |
| - self.rel_pos_w: Optional[nn.Module] = None |
295 |
| - self.rel_pos_t: Optional[nn.Module] = None |
| 294 | + self.rel_pos_h: Optional[nn.Parameter] = None |
| 295 | + self.rel_pos_w: Optional[nn.Parameter] = None |
| 296 | + self.rel_pos_t: Optional[nn.Parameter] = None |
296 | 297 | if rel_pos:
|
297 | 298 | assert input_size[1] == input_size[2] # TODO: remove this limitation
|
298 | 299 | size = input_size[1]
|
@@ -471,6 +472,8 @@ def __init__(
|
471 | 472 | temporal_size (int): The temporal size ``T`` of the input.
|
472 | 473 | block_setting (sequence of MSBlockConfig): The Network structure.
|
473 | 474 | residual_pool (bool): If True, use MViTv2 pooling residual connection.
|
| 475 | + rel_pos (bool): TODO |
| 476 | + dim_mul_in_att (bool): TODO |
474 | 477 | dropout (float): Dropout rate. Default: 0.0.
|
475 | 478 | attention_dropout (float): Attention dropout rate. Default: 0.0.
|
476 | 479 | stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
|
@@ -508,7 +511,7 @@ def __init__(
|
508 | 511 | # Spatio-Temporal Class Positional Encoding
|
509 | 512 | self.pos_encoding = PositionalEncoding(
|
510 | 513 | embed_size=block_setting[0].input_channels,
|
511 |
| - spatial_size=tuple(input_size[1:]), |
| 514 | + spatial_size=(input_size[1], input_size[2]), |
512 | 515 | temporal_size=input_size[0],
|
513 | 516 | rel_pos=rel_pos,
|
514 | 517 | )
|
|
0 commit comments