File tree 1 file changed +5
-4
lines changed
1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -193,19 +193,20 @@ def __init__(
193
193
relative_coords [:, :, 1 ] += self .window_size - 1
194
194
relative_coords [:, :, 0 ] *= 2 * self .window_size - 1
195
195
relative_position_index = relative_coords .sum (- 1 ).view (- 1 ) # Wh*Ww*Wh*Ww
196
-
196
+
197
197
# define a parameter table of relative position bias
198
- relative_position_bias_table = torch .zeros ((2 * window_size - 1 ) * (2 * window_size - 1 ), num_heads ) # 2*Wh-1 * 2*Ww-1, nH
198
+ relative_position_bias_table = torch .zeros (
199
+ (2 * window_size - 1 ) * (2 * window_size - 1 ), num_heads
200
+ ) # 2*Wh-1 * 2*Ww-1, nH
199
201
nn .init .trunc_normal_ (relative_position_bias_table , std = 0.02 )
200
-
202
+
201
203
relative_position_bias = relative_position_bias_table [relative_position_index ] # type: ignore[index]
202
204
relative_position_bias = relative_position_bias .view (
203
205
self .window_size * self .window_size , self .window_size * self .window_size , - 1
204
206
)
205
207
self .relative_position_bias = nn .Parameter (relative_position_bias .permute (2 , 0 , 1 ).contiguous ().unsqueeze (0 ))
206
208
207
209
def forward (self , x : Tensor ):
208
-
209
210
210
211
return shifted_window_attention (
211
212
x ,
You can’t perform that action at this time.
0 commit comments