```diff
@@ -1,6 +1,6 @@
 import math
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, List, Optional, Union

 import torch
 import torch.nn.functional as F
```
```diff
@@ -143,18 +143,21 @@ def shifted_window_attention(
         Tensor[N, H, W, C]: The output tensor after shifted window attention.
     """
     B, H, W, C = input.shape
+
+    # If window size is larger than feature size, there is no need to shift window
+    if window_size[0] >= H:
+        shift_size[0] = 0
+        window_size[0] = H
+    if window_size[1] >= W:
+        shift_size[1] = 0
+        window_size[1] = W
+
     # pad feature maps to multiples of window size
     pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
     pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
     x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
     _, pad_H, pad_W, _ = x.shape

-    # If window size is larger than feature size, there is no need to shift window
-    if window_size[0] >= pad_H:
-        shift_size[0] = 0
-    if window_size[1] >= pad_W:
-        shift_size[1] = 0
-
     # cyclic shift
     if sum(shift_size) > 0:
         x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
```
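Moving the check above the padding also changes what it does: the old code padded first and compared the window against the already-padded size (`pad_H`, `pad_W`), only zeroing the shift, while the new code clamps `window_size` itself to the feature size before padding. As a result, a feature map smaller than the window is no longer padded up to the window size at all. A minimal sketch of just that clamp-and-pad arithmetic (the standalone function `clamp_and_pad` is illustrative, not the library API; like the patched code, it mutates the lists in place):

```python
import torch
import torch.nn.functional as F

def clamp_and_pad(x, window_size, shift_size):
    # x: [B, H, W, C]; window_size / shift_size: 2-element lists,
    # mutated in place as in the patched shifted_window_attention
    B, H, W, C = x.shape
    # clamp the window to the feature size *before* padding (the fix)
    if window_size[0] >= H:
        shift_size[0] = 0
        window_size[0] = H
    if window_size[1] >= W:
        shift_size[1] = 0
        window_size[1] = W
    # pad feature map to multiples of the (possibly clamped) window size;
    # F.pad pads last dims first: (C-left, C-right, W-left, W-right, H-top, H-bottom)
    pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
    pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
    return F.pad(x, (0, 0, 0, pad_r, 0, pad_b)), window_size, shift_size

x = torch.randn(1, 4, 4, 96)  # feature map smaller than a 7x7 window
padded, ws, ss = clamp_and_pad(x, [7, 7], [3, 3])
print(padded.shape, ws, ss)   # torch.Size([1, 4, 4, 96]) [4, 4] [0, 0]
```

With the old ordering, the same 4x4 input was first padded to 7x7 (`pad_r = pad_b = 3`) and windows were computed over the padded map; with the fix, no padding occurs and the whole feature map becomes a single unshifted window.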
```diff
@@ -479,7 +482,7 @@ class SwinTransformer(nn.Module):
         embed_dim (int): Patch embedding dimension.
         depths (List(int)): Depth of each Swin Transformer layer.
         num_heads (List(int)): Number of attention heads in different layers.
-        window_size (List[int]): Window size.
+        window_size (int, List[int]): Window size.
         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
         stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1.
         num_classes (int): Number of classes for classification head. Default: 1000.
```
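The docstring update, together with the new `Union` import at the top of the file, suggests `window_size` may now be passed as a single int as well as a per-dimension list. How the int form is expanded is not shown in this diff; a hedged sketch of one plausible normalization (the helper name `_normalize_window_size` is hypothetical, not torchvision's API):

```python
from typing import List, Union

def _normalize_window_size(window_size: Union[int, List[int]], num_dims: int = 2) -> List[int]:
    # Hypothetical helper: expand a scalar window size to one entry per spatial dim.
    if isinstance(window_size, int):
        return [window_size] * num_dims
    return list(window_size)

print(_normalize_window_size(7))       # [7, 7]
print(_normalize_window_size([8, 7]))  # [8, 7]
```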