Skip to content

Commit 1c7579c

Browse files
committed
fix dynamic padding edge case
1 parent f4ea5df commit 1c7579c

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

torchvision/models/swin_transformer.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from functools import partial
3-
from typing import Any, Callable, Dict, List, Optional
3+
from typing import Any, Callable, List, Optional, Union
44

55
import torch
66
import torch.nn.functional as F
@@ -143,18 +143,21 @@ def shifted_window_attention(
143143
Tensor[N, H, W, C]: The output tensor after shifted window attention.
144144
"""
145145
B, H, W, C = input.shape
146+
147+
# If the window size is at least the feature size, clamp the window to the feature size and skip the shift
148+
if window_size[0] >= H:
149+
shift_size[0] = 0
150+
window_size[0] = H
151+
if window_size[1] >= W:
152+
shift_size[1] = 0
153+
window_size[1] = W
154+
146155
# pad feature maps to multiples of window size
147156
pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
148157
pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
149158
x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
150159
_, pad_H, pad_W, _ = x.shape
151160

152-
# If window size is larger than feature size, there is no need to shift window
153-
if window_size[0] >= pad_H:
154-
shift_size[0] = 0
155-
if window_size[1] >= pad_W:
156-
shift_size[1] = 0
157-
158161
# cyclic shift
159162
if sum(shift_size) > 0:
160163
x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
@@ -479,7 +482,7 @@ class SwinTransformer(nn.Module):
479482
embed_dim (int): Patch embedding dimension.
480483
depths (List[int]): Depth of each Swin Transformer layer.
481484
num_heads (List[int]): Number of attention heads in different layers.
482-
window_size (List[int]): Window size.
485+
window_size (int, List[int]): Window size.
483486
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
484487
stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1.
485488
num_classes (int): Number of classes for classification head. Default: 1000.

0 commit comments

Comments
 (0)