Skip to content

Commit 9dea7ff

Browse files
YosuaMichaeldatumbox
authored andcommitted
[fbsync] Update vision_transformer.py (#5820)
Summary: the assert msg should be same Reviewed By: jdsgomes, NicolasHug Differential Revision: D36095680 fbshipit-source-id: 590b6befa239e06b077778886c058da1c13be550 Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 374ace8 commit 9dea7ff

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torchvision/models/vision_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def __init__(
7979
self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
8080

8181
def forward(self, input: torch.Tensor):
82-
torch._assert(input.dim() == 3, f"Expected (seq_length, batch_size, hidden_dim) got {input.shape}")
82+
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
8383
x = self.ln_1(input)
8484
x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
8585
x = self.dropout(x)

0 commit comments

Comments
 (0)