File tree Expand file tree Collapse file tree 1 file changed +9
-2
lines changed
torchvision/prototype/models Expand file tree Collapse file tree 1 file changed +9
-2
lines changed Original file line number Diff line number Diff line change @@ -202,7 +202,7 @@ def _init_weights(self):
202
202
nn .init .zeros_ (self .heads .head .weight )
203
203
nn .init .zeros_ (self .heads .head .bias )
204
204
205
- def forward (self , x : torch .Tensor ):
205
+ def _process_input (self , x : torch .Tensor ) -> torch . Tensor :
206
206
n , c , h , w = x .shape
207
207
p = self .patch_size
208
208
torch ._assert (h == self .image_size , "Wrong image height!" )
@@ -221,7 +221,14 @@ def forward(self, x: torch.Tensor):
221
221
# embedding dimension
222
222
x = x .permute (0 , 2 , 1 )
223
223
224
- # Expand the class token to the full batch.
224
+ return x
225
+
226
+ def forward (self , x : torch .Tensor ):
227
+ # Reshaping and permuting the input tensor
228
+ x = self ._process_input (x )
229
+ n = x .shape [0 ]
230
+
231
+ # Expand the class token to the full batch
225
232
batch_class_token = self .class_token .expand (n , - 1 , - 1 )
226
233
x = torch .cat ([batch_class_token , x ], dim = 1 )
227
234
You can’t perform that action at this time.
0 commit comments