Commit 4c35b78

Add device/dtype factory kwargs to beit, efficientformer*, efficientvit*, focalnet, levit. Fix ByobNet
1 parent f15f7c9 commit 4c35b78

9 files changed: 1078 additions & 654 deletions
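
Every model file touched here applies the same convention: the constructor's new `device` / `dtype` arguments are collected into a single `dd` dict and forwarded to each layer, parameter, and buffer factory it creates. The sketch below shows that pattern in isolation (a hypothetical module, not timm code), so the per-file diffs that follow are easier to scan:

import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Illustrative module using the device/dtype factory-kwarg convention."""

    def __init__(self, dim: int, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}  # forwarded to every factory below
        super().__init__()
        self.norm = nn.LayerNorm(dim, **dd)               # torch.nn layers take factory kwargs
        self.proj = nn.Linear(dim, dim, **dd)
        self.gamma = nn.Parameter(torch.ones(dim, **dd))  # explicit tensor factories too
        self.register_buffer('ids', torch.arange(dim, device=device), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.gamma * self.proj(self.norm(x))


# Parameters land directly on the target device / in the target dtype, e.g. the
# meta device for deferred allocation, instead of fp32-on-CPU followed by .to().
blk = TinyBlock(64, device='meta')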

timm/models/beit.py

Lines changed: 62 additions & 27 deletions
@@ -39,15 +39,27 @@
 # --------------------------------------------------------'

 import math
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, calculate_drop_path_rates, trunc_normal_, use_fused_attn
-from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid
+from timm.layers import (
+    PatchEmbed,
+    Mlp,
+    SwiGLU,
+    LayerNorm,
+    DropPath,
+    calculate_drop_path_rates,
+    trunc_normal_,
+    use_fused_attn,
+    resample_patch_embed,
+    resample_abs_pos_embed,
+    resize_rel_pos_bias_table,
+    ndgrid,
+)

 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
@@ -57,7 +69,7 @@
 __all__ = ['Beit']


-def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
+def gen_relative_position_index(window_size: Tuple[int, int], device=None) -> torch.Tensor:
     """Generate relative position index for window-based attention.

     Creates a lookup table for relative position indices between all pairs of positions
@@ -74,14 +86,17 @@ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
     # cls to token & token 2 cls & cls to cls
     # get pair-wise relative position index for each token inside the window
     window_area = window_size[0] * window_size[1]
-    coords = torch.stack(ndgrid(torch.arange(window_size[0]), torch.arange(window_size[1])))  # 2, Wh, Ww
+    coords = torch.stack(ndgrid(
+        torch.arange(window_size[0], device=device, dtype=torch.long),
+        torch.arange(window_size[1], device=device, dtype=torch.long),
+    ))  # 2, Wh, Ww
     coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
     relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
     relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
     relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
     relative_coords[:, :, 1] += window_size[1] - 1
     relative_coords[:, :, 0] *= 2 * window_size[1] - 1
-    relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
+    relative_position_index = torch.zeros(size=(window_area + 1,) * 2, device=device, dtype=relative_coords.dtype)
     relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
     relative_position_index[0, 0:] = num_relative_distance - 3
     relative_position_index[0:, 0] = num_relative_distance - 2
@@ -107,6 +122,8 @@ def __init__(
             proj_drop: float = 0.,
             window_size: Optional[Tuple[int, int]] = None,
             attn_head_dim: Optional[int] = None,
+            device=None,
+            dtype=None,
     ):
         """Initialize attention module.

@@ -120,6 +137,7 @@ def __init__(
             window_size: Window size for relative position bias. If None, no relative position bias.
             attn_head_dim: Dimension per attention head. If None, uses dim // num_heads.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
@@ -130,11 +148,11 @@ def __init__(
         self.fused_attn = use_fused_attn()
         self.qkv_bias_separate = qkv_bias_separate

-        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False, **dd)
         if qkv_bias:
-            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
-            self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
-            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+            self.q_bias = nn.Parameter(torch.zeros(all_head_dim, **dd))
+            self.register_buffer('k_bias', torch.zeros(all_head_dim, **dd), persistent=False)
+            self.v_bias = nn.Parameter(torch.zeros(all_head_dim, **dd))
         else:
             self.q_bias = None
             self.k_bias = None
@@ -144,15 +162,19 @@ def __init__(
             self.window_size = window_size
             self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
             self.relative_position_bias_table = nn.Parameter(
-                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
-            self.register_buffer("relative_position_index", gen_relative_position_index(window_size), persistent=False)
+                torch.zeros(self.num_relative_distance, num_heads, **dd))  # 2*Wh-1 * 2*Ww-1, nH
+            self.register_buffer(
+                "relative_position_index",
+                gen_relative_position_index(window_size, device=device),
+                persistent=False,
+            )
         else:
             self.window_size = None
             self.relative_position_bias_table = None
             self.relative_position_index = None

         self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(all_head_dim, dim)
+        self.proj = nn.Linear(all_head_dim, dim, **dd)
         self.proj_drop = nn.Dropout(proj_drop)

     def _get_rel_pos_bias(self) -> torch.Tensor:
@@ -245,10 +267,12 @@ def __init__(
             attn_drop: float = 0.,
             drop_path: float = 0.,
             init_values: Optional[float] = None,
-            act_layer: Callable = nn.GELU,
-            norm_layer: Callable = LayerNorm,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = LayerNorm,
             window_size: Optional[Tuple[int, int]] = None,
             attn_head_dim: Optional[int] = None,
+            device=None,
+            dtype=None,
     ):
         """Initialize transformer block.

@@ -268,8 +292,9 @@ def __init__(
             window_size: Window size for relative position bias in attention.
             attn_head_dim: Dimension per attention head.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.norm1 = norm_layer(dim)
+        self.norm1 = norm_layer(dim, **dd)
         self.attn = Attention(
             dim,
             num_heads=num_heads,
@@ -278,17 +303,19 @@ def __init__(
             proj_drop=proj_drop,
             window_size=window_size,
             attn_head_dim=attn_head_dim,
+            **dd,
         )
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

-        self.norm2 = norm_layer(dim)
+        self.norm2 = norm_layer(dim, **dd)
         if swiglu_mlp:
             self.mlp = SwiGLU(
                 in_features=dim,
                 hidden_features=int(dim * mlp_ratio),
                 norm_layer=norm_layer if scale_mlp else None,
                 drop=proj_drop,
+                **dd,
             )
         else:
             self.mlp = Mlp(
@@ -297,12 +324,13 @@ def __init__(
                 act_layer=act_layer,
                 norm_layer=norm_layer if scale_mlp else None,
                 drop=proj_drop,
+                **dd,
             )
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

         if init_values:
-            self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
-            self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
+            self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd))
+            self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd))
         else:
             self.gamma_1, self.gamma_2 = None, None

@@ -332,18 +360,19 @@ class RelativePositionBias(nn.Module):
     within a window, including special handling for cls token.
     """

-    def __init__(self, window_size: Tuple[int, int], num_heads: int):
+    def __init__(self, window_size: Tuple[int, int], num_heads: int, device=None, dtype=None):
         """Initialize relative position bias module.

         Args:
             window_size: Height and width of the attention window.
             num_heads: Number of attention heads.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.window_size = window_size
         self.window_area = window_size[0] * window_size[1]
         num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
-        self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
+        self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads, **dd))
         # trunc_normal_(self.relative_position_bias_table, std=.02)
         self.register_buffer("relative_position_index", gen_relative_position_index(window_size))

@@ -385,12 +414,14 @@ def __init__(
             proj_drop_rate: float = 0.,
             attn_drop_rate: float = 0.,
             drop_path_rate: float = 0.,
-            norm_layer: Callable = LayerNorm,
+            norm_layer: Type[nn.Module] = LayerNorm,
             init_values: Optional[float] = None,
             use_abs_pos_emb: bool = True,
             use_rel_pos_bias: bool = False,
             use_shared_rel_pos_bias: bool = False,
             head_init_scale: float = 0.001,
+            device=None,
+            dtype=None,
     ):
         """Initialize BEiT model.

@@ -419,6 +450,7 @@ def __init__(
             use_shared_rel_pos_bias: If True, share relative position bias across layers.
             head_init_scale: Scale factor for head initialization.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.num_classes = num_classes
         self.global_pool = global_pool
@@ -431,19 +463,21 @@ def __init__(
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
+            **dd,
         )
         num_patches = self.patch_embed.num_patches
         r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size

-        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd))
         # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim, **dd)) if use_abs_pos_emb else None
         self.pos_drop = nn.Dropout(p=pos_drop_rate)

         if use_shared_rel_pos_bias:
             self.rel_pos_bias = RelativePositionBias(
                 window_size=self.patch_embed.grid_size,
                 num_heads=num_heads,
+                **dd,
             )
         else:
             self.rel_pos_bias = None
@@ -463,16 +497,17 @@ def __init__(
                 norm_layer=norm_layer,
                 init_values=init_values,
                 window_size=self.patch_embed.grid_size if use_rel_pos_bias else None,
+                **dd,
             )
             for i in range(depth)])
         self.feature_info = [
             dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]

         use_fc_norm = self.global_pool == 'avg'
-        self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
-        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
+        self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim, **dd)
+        self.fc_norm = norm_layer(embed_dim, **dd) if use_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
-        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()

         self.apply(self._init_weights)
         if self.pos_embed is not None:
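
Usage note (not part of the diff): with the kwargs above, the BEiT modules can be instantiated directly in a target dtype or on a target device rather than built in fp32 on CPU and converted afterwards. A minimal sketch using the Attention block changed above, with illustrative sizes:

import torch
from timm.models.beit import Attention

# Weights, the relative position bias table and the index buffer are created
# directly in the requested dtype/device; nothing is allocated in fp32 first.
attn = Attention(
    dim=768,
    num_heads=12,
    qkv_bias=True,
    window_size=(14, 14),
    dtype=torch.bfloat16,  # new factory kwarg
    # device='cuda',       # likewise supported after this change
)
print(attn.qkv.weight.dtype)               # torch.bfloat16
print(attn.relative_position_index.dtype)  # torch.int64 (index buffer only takes device=)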

timm/models/byobnet.py

Lines changed: 1 addition & 1 deletion
@@ -478,7 +478,7 @@ def __init__(
         self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False, **dd)
         self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs, **dd)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
-        self.act = nn.Identity() if linear_out else layers.act(inplace=True, **dd)
+        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

     def init_weights(self, zero_init_last: bool = False):
         if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
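
The one-line ByobNet fix above drops `**dd` from the activation constructor: `layers.act` builds typically parameter-free activation modules (e.g. nn.ReLU, nn.SiLU), whose constructors have nothing to place on a device, so forwarding the factory kwargs can raise a TypeError. A quick illustration with plain torch.nn (not timm code):

import torch.nn as nn

nn.ReLU(inplace=True)           # fine: no parameters, no factory kwargs needed
nn.Linear(8, 8, device='cpu')   # fine: parameterized layers accept device/dtype
try:
    nn.ReLU(inplace=True, device='cpu')  # roughly what the removed **dd did
except TypeError as err:
    print(err)  # __init__() got an unexpected keyword argument 'device'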

timm/models/edgenext.py

Lines changed: 11 additions & 2 deletions
@@ -16,8 +16,17 @@
 from torch import nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import trunc_normal_tf_, DropPath, calculate_drop_path_rates, LayerNorm2d, Mlp, create_conv2d, \
-    NormMlpClassifierHead, ClassifierHead
+from timm.layers import (
+    DropPath,
+    calculate_drop_path_rates,
+    LayerNorm2d,
+    Mlp,
+    create_conv2d,
+    NormMlpClassifierHead,
+    ClassifierHead,
+    trunc_normal_tf_,
+)
+
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
