[Cherrypick] Refactor Swin Transformer so its components can be reused in the 3D version #6100

Merged
146 changes: 83 additions & 63 deletions torchvision/models/swin_transformer.py
@@ -39,18 +39,23 @@ def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm
self.norm = norm_layer(4 * dim)

def forward(self, x: Tensor):
B, H, W, C = x.shape

x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
"""
Args:
x (Tensor): input tensor with expected layout of [..., H, W, C]
Returns:
Tensor with layout of [..., H/2, W/2, 2*C]
"""
H, W, _ = x.shape[-3:]
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C
x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C

x = self.norm(x)
x = self.reduction(x)
x = x.view(B, H // 2, W // 2, 2 * C)
x = self.reduction(x) # ... H/2 W/2 2*C
return x
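
Illustrative sketch (not part of the diff): with the generalized forward above, patch merging no longer assumes a 4-D [B, H, W, C] input. Assuming the enclosing class is this module's PatchMerging (its name is not visible in this hunk), a quick shape check could look like:

    import torch
    from torchvision.models.swin_transformer import PatchMerging  # assumed name of the class patched above

    merge = PatchMerging(dim=96)

    # per the new docstring: [..., H, W, C] -> [..., H/2, W/2, 2*C]
    x = torch.rand(2, 56, 56, 96)          # [B, H, W, C]
    print(merge(x).shape)                  # torch.Size([2, 28, 28, 192])

    # extra leading dims now pass through, e.g. [B, T, H, W, C] for a 3D/video variant
    x3d = torch.rand(2, 4, 56, 56, 96)
    print(merge(x3d).shape)                # torch.Size([2, 4, 28, 28, 192])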


@@ -59,9 +64,9 @@ def shifted_window_attention(
qkv_weight: Tensor,
proj_weight: Tensor,
relative_position_bias: Tensor,
window_size: int,
window_size: List[int],
num_heads: int,
shift_size: int = 0,
shift_size: List[int],
attention_dropout: float = 0.0,
dropout: float = 0.0,
qkv_bias: Optional[Tensor] = None,
@@ -75,9 +80,9 @@ def shifted_window_attention(
qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value.
proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection.
relative_position_bias (Tensor): The learned relative position bias added to attention.
window_size (int): Window size.
window_size (List[int]): Window size.
num_heads (int): Number of attention heads.
shift_size (int): Shift size for shifted window attention. Default: 0.
shift_size (List[int]): Shift size for shifted window attention.
attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
dropout (float): Dropout ratio of output. Default: 0.0.
qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
@@ -87,23 +92,25 @@
"""
B, H, W, C = input.shape
# pad feature maps to multiples of window size
pad_r = (window_size - W % window_size) % window_size
pad_b = (window_size - H % window_size) % window_size
pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
_, pad_H, pad_W, _ = x.shape

# If window size is larger than feature size, there is no need to shift window.
if window_size == min(pad_H, pad_W):
shift_size = 0
# If window size is larger than feature size, there is no need to shift window
if window_size[0] >= pad_H:
shift_size[0] = 0
if window_size[1] >= pad_W:
shift_size[1] = 0

# cyclic shift
if shift_size > 0:
x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
if sum(shift_size) > 0:
x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))

# partition windows
num_windows = (pad_H // window_size) * (pad_W // window_size)
x = x.view(B, pad_H // window_size, window_size, pad_W // window_size, window_size, C)
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size * window_size, C) # B*nW, Ws*Ws, C
num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C

# multi-head attention
qkv = F.linear(x, qkv_weight, qkv_bias)
@@ -114,17 +121,18 @@ def shifted_window_attention(
# add relative position bias
attn = attn + relative_position_bias

if shift_size > 0:
if sum(shift_size) > 0:
# generate attention mask
attn_mask = x.new_zeros((pad_H, pad_W))
slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, None))
h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None))
w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None))
count = 0
for h in slices:
for w in slices:
for h in h_slices:
for w in w_slices:
attn_mask[h[0] : h[1], w[0] : w[1]] = count
count += 1
attn_mask = attn_mask.view(pad_H // window_size, window_size, pad_W // window_size, window_size)
attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size * window_size)
attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1])
attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1])
attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
@@ -139,12 +147,12 @@ def shifted_window_attention(
x = F.dropout(x, p=dropout)

# reverse windows
x = x.view(B, pad_H // window_size, pad_W // window_size, window_size, window_size, C)
x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C)

# reverse cyclic shift
if shift_size > 0:
x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2))
if sum(shift_size) > 0:
x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))

# unpad features
x = x[:, :H, :W, :].contiguous()
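
Illustrative sketch (not part of the diff): the functional shifted_window_attention now takes window_size and shift_size as two-element lists. The weights below are random placeholders, only the shapes matter:

    import torch
    from torchvision.models.swin_transformer import shifted_window_attention

    B, H, W, C, num_heads = 2, 56, 56, 96, 3
    window_size, shift_size = [7, 7], [3, 3]
    N = window_size[0] * window_size[1]        # tokens per window

    x = torch.rand(B, H, W, C)
    qkv_weight = torch.rand(3 * C, C)          # placeholder qkv projection weight
    proj_weight = torch.rand(C, C)             # placeholder output projection weight
    # one bias value per head and per pair of positions inside a window
    relative_position_bias = torch.rand(1, num_heads, N, N)

    out = shifted_window_attention(
        x, qkv_weight, proj_weight, relative_position_bias, window_size, num_heads, shift_size=shift_size
    )
    print(out.shape)                           # torch.Size([2, 56, 56, 96])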
@@ -162,15 +170,17 @@ class ShiftedWindowAttention(nn.Module):
def __init__(
self,
dim: int,
window_size: int,
shift_size: int,
window_size: List[int],
shift_size: List[int],
num_heads: int,
qkv_bias: bool = True,
proj_bias: bool = True,
attention_dropout: float = 0.0,
dropout: float = 0.0,
):
super().__init__()
if len(window_size) != 2 or len(shift_size) != 2:
raise ValueError("window_size and shift_size must be of length 2")
self.window_size = window_size
self.shift_size = shift_size
self.num_heads = num_heads
@@ -182,29 +192,35 @@ def __init__(

# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
) # 2*Wh-1 * 2*Ww-1, nH

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size)
coords_w = torch.arange(self.window_size)
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size - 1
relative_coords[:, :, 0] *= 2 * self.window_size - 1
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)

nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

def forward(self, x: Tensor):
"""
Args:
x (Tensor): Tensor with layout of [B, H, W, C]
Returns:
Tensor with same layout as input, i.e. [B, H, W, C]
"""

N = self.window_size[0] * self.window_size[1]
relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index]
relative_position_bias = relative_position_bias.view(
self.window_size * self.window_size, self.window_size * self.window_size, -1
)
relative_position_bias = relative_position_bias.view(N, N, -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)

return shifted_window_attention(
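
Illustrative sketch (not part of the diff): the ShiftedWindowAttention module mirrors the functional change, taking two-element lists and rejecting anything else in the length check above. The tensor sizes are arbitrary:

    import torch
    from torchvision.models.swin_transformer import ShiftedWindowAttention

    attn = ShiftedWindowAttention(
        dim=96,
        window_size=[7, 7],
        shift_size=[3, 3],   # w // 2, as SwinTransformerBlock uses for the shifted blocks
        num_heads=3,
    )

    x = torch.rand(2, 56, 56, 96)  # [B, H, W, C]
    print(attn(x).shape)           # torch.Size([2, 56, 56, 96])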
@@ -228,31 +244,33 @@ class SwinTransformerBlock(nn.Module):
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size. Default: 7.
shift_size (int): Shift size for shifted window attention. Default: 0.
window_size (List[int]): Window size.
shift_size (List[int]): Shift size for shifted window attention.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
dropout (float): Dropout rate. Default: 0.0.
attention_dropout (float): Attention dropout rate. Default: 0.0.
stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention
"""

def __init__(
self,
dim: int,
num_heads: int,
window_size: int = 7,
shift_size: int = 0,
window_size: List[int],
shift_size: List[int],
mlp_ratio: float = 4.0,
dropout: float = 0.0,
attention_dropout: float = 0.0,
stochastic_depth_prob: float = 0.0,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention,
):
super().__init__()

self.norm1 = norm_layer(dim)
self.attn = ShiftedWindowAttention(
self.attn = attn_layer(
dim,
window_size,
shift_size,
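
Illustrative sketch (not part of the diff): the new attn_layer argument is the hook that lets a 3D variant plug in its own attention module; any callable with the same (dim, window_size, shift_size, num_heads, ...) constructor should fit. Shown here with the default attention for clarity, sizes arbitrary:

    import torch
    from torchvision.models.swin_transformer import ShiftedWindowAttention, SwinTransformerBlock

    block = SwinTransformerBlock(
        dim=96,
        num_heads=3,
        window_size=[7, 7],
        shift_size=[3, 3],
        attn_layer=ShiftedWindowAttention,  # default; a 3D version would pass its own class here
    )

    x = torch.rand(2, 56, 56, 96)  # [B, H, W, C]
    print(block(x).shape)          # torch.Size([2, 56, 56, 96])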
@@ -281,11 +299,11 @@ class SwinTransformer(nn.Module):
Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
Shifted Windows" <https://arxiv.org/pdf/2103.14030>`_ paper.
Args:
patch_size (int): Patch size.
patch_size (List[int]): Patch size.
embed_dim (int): Patch embedding dimension.
depths (List(int)): Depth of each Swin Transformer layer.
num_heads (List(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7.
window_size (List[int]): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
dropout (float): Dropout rate. Default: 0.0.
attention_dropout (float): Attention dropout rate. Default: 0.0.
Expand All @@ -297,11 +315,11 @@ class SwinTransformer(nn.Module):

def __init__(
self,
patch_size: int,
patch_size: List[int],
embed_dim: int,
depths: List[int],
num_heads: List[int],
window_size: int = 7,
window_size: List[int],
mlp_ratio: float = 4.0,
dropout: float = 0.0,
attention_dropout: float = 0.0,
@@ -324,7 +342,9 @@ def __init__(
# split image into non-overlapping patches
layers.append(
nn.Sequential(
nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size),
nn.Conv2d(
3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1])
),
Permute([0, 2, 3, 1]),
norm_layer(embed_dim),
)
@@ -344,7 +364,7 @@ def __init__(
dim,
num_heads[i_stage],
window_size=window_size,
shift_size=0 if i_layer % 2 == 0 else window_size // 2,
shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size],
mlp_ratio=mlp_ratio,
dropout=dropout,
attention_dropout=attention_dropout,
@@ -381,11 +401,11 @@ def forward(self, x):


def _swin_transformer(
patch_size: int,
patch_size: List[int],
embed_dim: int,
depths: List[int],
num_heads: List[int],
window_size: int,
window_size: List[int],
stochastic_depth_prob: float,
weights: Optional[WeightsEnum],
progress: bool,
@@ -508,11 +528,11 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, *
weights = Swin_T_Weights.verify(weights)

return _swin_transformer(
patch_size=4,
patch_size=[4, 4],
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
window_size=[7, 7],
stochastic_depth_prob=0.2,
weights=weights,
progress=progress,
@@ -544,11 +564,11 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, *
weights = Swin_S_Weights.verify(weights)

return _swin_transformer(
patch_size=4,
patch_size=[4, 4],
embed_dim=96,
depths=[2, 2, 18, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
window_size=[7, 7],
stochastic_depth_prob=0.3,
weights=weights,
progress=progress,
@@ -580,11 +600,11 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, *
weights = Swin_B_Weights.verify(weights)

return _swin_transformer(
patch_size=4,
patch_size=[4, 4],
embed_dim=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
window_size=7,
window_size=[7, 7],
stochastic_depth_prob=0.5,
weights=weights,
progress=progress,
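
Illustrative sketch (not part of the diff): the public builders keep their keyword-only interface; they just pass the sizes as lists internally now (e.g. patch_size=[4, 4] and window_size=[7, 7] for swin_t above). A quick smoke test, assuming the default num_classes=1000:

    import torch
    from torchvision.models.swin_transformer import swin_t

    model = swin_t(weights=None).eval()  # same call as before the int -> List[int] change
    with torch.no_grad():
        out = model(torch.rand(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000])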