Skip to content

Commit 4027ebc

Browse files
prabhat00155facebook-github-bot
authored andcommitted
[fbsync] [ViT] Refactor forward function (#5172)
Summary: * refactor forward function * reemove n from return Reviewed By: sallysyw Differential Revision: D33479279 fbshipit-source-id: a228ee12d9c1c0fbf340a8f0af5d2d7c0bc52bfa
1 parent 48895ce commit 4027ebc

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

torchvision/prototype/models/vision_transformer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def _init_weights(self):
202202
nn.init.zeros_(self.heads.head.weight)
203203
nn.init.zeros_(self.heads.head.bias)
204204

205-
def forward(self, x: torch.Tensor):
205+
def _process_input(self, x: torch.Tensor) -> torch.Tensor:
206206
n, c, h, w = x.shape
207207
p = self.patch_size
208208
torch._assert(h == self.image_size, "Wrong image height!")
@@ -221,7 +221,14 @@ def forward(self, x: torch.Tensor):
221221
# embedding dimension
222222
x = x.permute(0, 2, 1)
223223

224-
# Expand the class token to the full batch.
224+
return x
225+
226+
def forward(self, x: torch.Tensor):
227+
# Reshaping and permuting the input tensor
228+
x = self._process_input(x)
229+
n = x.shape[0]
230+
231+
# Expand the class token to the full batch
225232
batch_class_token = self.class_token.expand(n, -1, -1)
226233
x = torch.cat([batch_class_token, x], dim=1)
227234

0 commit comments

Comments
 (0)