# --------------------------------------------------------'

import math
- from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, calculate_drop_path_rates, trunc_normal_, use_fused_attn
- from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid
+ from timm.layers import (
+     PatchEmbed,
+     Mlp,
+     SwiGLU,
+     LayerNorm,
+     DropPath,
+     calculate_drop_path_rates,
+     trunc_normal_,
+     use_fused_attn,
+     resample_patch_embed,
+     resample_abs_pos_embed,
+     resize_rel_pos_bias_table,
+     ndgrid,
+ )

from ._builder import build_model_with_cfg
from ._features import feature_take_indices

__all__ = ['Beit']


- def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
+ def gen_relative_position_index(window_size: Tuple[int, int], device=None) -> torch.Tensor:
    """Generate relative position index for window-based attention.

    Creates a lookup table for relative position indices between all pairs of positions
@@ -74,14 +86,17 @@ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
    # cls to token & token 2 cls & cls to cls
    # get pair-wise relative position index for each token inside the window
    window_area = window_size[0] * window_size[1]
-     coords = torch.stack(ndgrid(torch.arange(window_size[0]), torch.arange(window_size[1])))  # 2, Wh, Ww
+     coords = torch.stack(ndgrid(
+         torch.arange(window_size[0], device=device, dtype=torch.long),
+         torch.arange(window_size[1], device=device, dtype=torch.long),
+     ))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
    relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] += window_size[1] - 1
    relative_coords[:, :, 0] *= 2 * window_size[1] - 1
-     relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
+     relative_position_index = torch.zeros(size=(window_area + 1,) * 2, device=device, dtype=relative_coords.dtype)
    relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
    relative_position_index[0, 0:] = num_relative_distance - 3
    relative_position_index[0:, 0] = num_relative_distance - 2
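With the new `device` argument, the index buffer can be built directly on a target device instead of always on CPU (the `None` default preserves the old behavior). A minimal sketch of calling it; note that `gen_relative_position_index` is an internal helper of `timm.models.beit`, so the import path is an assumption and may change:

```python
import torch
from timm.models.beit import gen_relative_position_index  # internal helper; path may change

# Pick a target device; falls back to CPU when CUDA is unavailable.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
idx = gen_relative_position_index((14, 14), device=device)
print(idx.shape, idx.device)  # torch.Size([197, 197]) on the chosen device
```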
@@ -107,6 +122,8 @@ def __init__(
            proj_drop: float = 0.,
            window_size: Optional[Tuple[int, int]] = None,
            attn_head_dim: Optional[int] = None,
+             device=None,
+             dtype=None,
    ):
        """Initialize attention module.

@@ -120,6 +137,7 @@ def __init__(
            window_size: Window size for relative position bias. If None, no relative position bias.
            attn_head_dim: Dimension per attention head. If None, uses dim // num_heads.
        """
+         dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
@@ -130,11 +148,11 @@ def __init__(
        self.fused_attn = use_fused_attn()
        self.qkv_bias_separate = qkv_bias_separate

-         self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+         self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False, **dd)
        if qkv_bias:
-             self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
-             self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
-             self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+             self.q_bias = nn.Parameter(torch.zeros(all_head_dim, **dd))
+             self.register_buffer('k_bias', torch.zeros(all_head_dim, **dd), persistent=False)
+             self.v_bias = nn.Parameter(torch.zeros(all_head_dim, **dd))
        else:
            self.q_bias = None
            self.k_bias = None
@@ -144,15 +162,19 @@ def __init__(
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
-                 torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
-             self.register_buffer("relative_position_index", gen_relative_position_index(window_size), persistent=False)
+                 torch.zeros(self.num_relative_distance, num_heads, **dd))  # 2*Wh-1 * 2*Ww-1, nH
+             self.register_buffer(
+                 "relative_position_index",
+                 gen_relative_position_index(window_size, device=device),
+                 persistent=False,
+             )
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
-         self.proj = nn.Linear(all_head_dim, dim)
+         self.proj = nn.Linear(all_head_dim, dim, **dd)
        self.proj_drop = nn.Dropout(proj_drop)

    def _get_rel_pos_bias(self) -> torch.Tensor:
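The `dd = {'device': device, 'dtype': dtype}` dict threaded through the constructors above is the standard PyTorch factory-kwargs pattern: every parameter, buffer, and submodule is created directly with the requested device/dtype, which enables e.g. meta-device (deferred) initialization. Since the defaults are `None`, the old behavior is preserved exactly. A standalone sketch of the pattern, using a hypothetical `TinyProj` module that is not part of timm:

```python
import torch
import torch.nn as nn


class TinyProj(nn.Module):
    """Toy module illustrating device/dtype factory kwargs (hypothetical, not in timm)."""

    def __init__(self, dim: int, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.proj = nn.Linear(dim, dim, **dd)             # submodule built on target device/dtype
        self.scale = nn.Parameter(torch.ones(dim, **dd))  # raw parameter likewise

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x) * self.scale


m = TinyProj(64, device='meta')  # shapes exist, but no real storage is allocated
print(m.proj.weight.device)      # meta
```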
@@ -245,10 +267,12 @@ def __init__(
            attn_drop: float = 0.,
            drop_path: float = 0.,
            init_values: Optional[float] = None,
-             act_layer: Callable = nn.GELU,
-             norm_layer: Callable = LayerNorm,
+             act_layer: Type[nn.Module] = nn.GELU,
+             norm_layer: Type[nn.Module] = LayerNorm,
            window_size: Optional[Tuple[int, int]] = None,
            attn_head_dim: Optional[int] = None,
+             device=None,
+             dtype=None,
    ):
        """Initialize transformer block.

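The `act_layer`/`norm_layer` annotations tighten from `Callable` to `Type[nn.Module]`; at runtime any callable that returns a module still works, including the `functools.partial` norm constructors commonly passed to timm models (a strict type checker may flag a partial, since it is not literally a class). A small illustrative sketch:

```python
from functools import partial
import torch.nn as nn

# Still a valid norm factory after this change: calling it yields LayerNorm(dim, eps=1e-6).
# With **dd now forwarded into norm_layer(dim, **dd), the factory must also accept
# device/dtype keyword arguments, as nn.LayerNorm does.
norm_layer = partial(nn.LayerNorm, eps=1e-6)
```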
@@ -268,8 +292,9 @@ def __init__(
            window_size: Window size for relative position bias in attention.
            attn_head_dim: Dimension per attention head.
        """
+         dd = {'device': device, 'dtype': dtype}
        super().__init__()
-         self.norm1 = norm_layer(dim)
+         self.norm1 = norm_layer(dim, **dd)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
@@ -278,17 +303,19 @@ def __init__(
            proj_drop=proj_drop,
            window_size=window_size,
            attn_head_dim=attn_head_dim,
+             **dd,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

-         self.norm2 = norm_layer(dim)
+         self.norm2 = norm_layer(dim, **dd)
        if swiglu_mlp:
            self.mlp = SwiGLU(
                in_features=dim,
                hidden_features=int(dim * mlp_ratio),
                norm_layer=norm_layer if scale_mlp else None,
                drop=proj_drop,
+                 **dd,
            )
        else:
            self.mlp = Mlp(
@@ -297,12 +324,13 @@ def __init__(
                act_layer=act_layer,
                norm_layer=norm_layer if scale_mlp else None,
                drop=proj_drop,
+                 **dd,
            )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        if init_values:
-             self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
-             self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
+             self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd))
+             self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd))
        else:
            self.gamma_1, self.gamma_2 = None, None
@@ -332,18 +360,19 @@ class RelativePositionBias(nn.Module):
    within a window, including special handling for cls token.
    """

-     def __init__(self, window_size: Tuple[int, int], num_heads: int):
+     def __init__(self, window_size: Tuple[int, int], num_heads: int, device=None, dtype=None):
        """Initialize relative position bias module.

        Args:
            window_size: Height and width of the attention window.
            num_heads: Number of attention heads.
        """
+         dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.window_size = window_size
        self.window_area = window_size[0] * window_size[1]
        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
-         self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
+         self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads, **dd))
        # trunc_normal_(self.relative_position_bias_table, std=.02)
        self.register_buffer("relative_position_index", gen_relative_position_index(window_size))
@@ -385,12 +414,14 @@ def __init__(
            proj_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
-             norm_layer: Callable = LayerNorm,
+             norm_layer: Type[nn.Module] = LayerNorm,
            init_values: Optional[float] = None,
            use_abs_pos_emb: bool = True,
            use_rel_pos_bias: bool = False,
            use_shared_rel_pos_bias: bool = False,
            head_init_scale: float = 0.001,
+             device=None,
+             dtype=None,
    ):
        """Initialize BEiT model.

@@ -419,6 +450,7 @@ def __init__(
            use_shared_rel_pos_bias: If True, share relative position bias across layers.
            head_init_scale: Scale factor for head initialization.
        """
+         dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.num_classes = num_classes
        self.global_pool = global_pool
@@ -431,19 +463,21 @@ def __init__(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
+             **dd,
        )
        num_patches = self.patch_embed.num_patches
        r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size

-         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd))
        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim, **dd)) if use_abs_pos_emb else None
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(
                window_size=self.patch_embed.grid_size,
                num_heads=num_heads,
+                 **dd,
            )
        else:
            self.rel_pos_bias = None
@@ -463,16 +497,17 @@ def __init__(
                norm_layer=norm_layer,
                init_values=init_values,
                window_size=self.patch_embed.grid_size if use_rel_pos_bias else None,
+                 **dd,
            )
            for i in range(depth)])
        self.feature_info = [
            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]

        use_fc_norm = self.global_pool == 'avg'
-         self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
-         self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
+         self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim, **dd)
+         self.fc_norm = norm_layer(embed_dim, **dd) if use_fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)
-         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+         self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)
        if self.pos_embed is not None:
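Taken together, these changes allow the whole model to be constructed with an explicit device/dtype, e.g. on the meta device for fast, allocation-free instantiation. A hedged usage sketch, assuming `timm.create_model` forwards the extra keyword arguments through to the `Beit` constructor (instantiating `timm.models.beit.Beit` directly is the alternative):

```python
import torch
import timm

# Build the module graph on the meta device: parameter shapes exist, no memory is allocated.
model = timm.create_model(
    'beit_base_patch16_224',
    pretrained=False,      # pretrained weights cannot be loaded onto meta tensors directly
    device='meta',
    dtype=torch.bfloat16,
)
# Materialize empty storage on a real device before loading a checkpoint or training.
model = model.to_empty(device='cpu')
```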