Skip to content

Commit 41faba2

Browse files
committed
fix formatting
1 parent f061544 commit 41faba2

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

torchvision/models/swin_transformer.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -193,19 +193,20 @@ def __init__(
193193
relative_coords[:, :, 1] += self.window_size - 1
194194
relative_coords[:, :, 0] *= 2 * self.window_size - 1
195195
relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww
196-
196+
197197
# define a parameter table of relative position bias
198-
relative_position_bias_table = torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads) # 2*Wh-1 * 2*Ww-1, nH
198+
relative_position_bias_table = torch.zeros(
199+
(2 * window_size - 1) * (2 * window_size - 1), num_heads
200+
) # 2*Wh-1 * 2*Ww-1, nH
199201
nn.init.trunc_normal_(relative_position_bias_table, std=0.02)
200-
202+
201203
relative_position_bias = relative_position_bias_table[relative_position_index] # type: ignore[index]
202204
relative_position_bias = relative_position_bias.view(
203205
self.window_size * self.window_size, self.window_size * self.window_size, -1
204206
)
205207
self.relative_position_bias = nn.Parameter(relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0))
206208

207209
def forward(self, x: Tensor):
208-
209210

210211
return shifted_window_attention(
211212
x,

0 commit comments

Comments
 (0)