
Commit a7dc50f

A whack of classic convnets converted with dd factory kwargs. densenet, dla, dpn, hrnet, inception_next, inception_resnet_v2, inception_v3/v4, senet, tresnet, vgg, vovnet, xception, xception_aligned
1 parent 1e172a0 commit a7dc50f

16 files changed: +1490 −797 lines
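
The changes shown below all follow the same pattern: each module's __init__ gains optional device/dtype parameters, packs them into a dd dict, and forwards **dd to every child layer that owns parameters or buffers, so a model can be constructed directly on a target device and in a target dtype rather than being moved after the fact. A minimal sketch of the idea, using a made-up TinyBlock module rather than anything from timm:

import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Hypothetical example block (not part of timm) mirroring the
    device/dtype 'factory kwargs' pattern applied in this commit."""

    def __init__(self, in_chs: int, out_chs: int, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        # Forward **dd to every child that owns parameters or buffers.
        self.conv = nn.Conv2d(in_chs, out_chs, 3, padding=1, bias=False, **dd)
        self.bn = nn.BatchNorm2d(out_chs, **dd)
        self.act = nn.ReLU(inplace=True)  # parameter-free, no **dd needed

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


# Parameters are created where requested; no post-hoc .to() copy is needed.
block = TinyBlock(3, 64, device='cpu', dtype=torch.bfloat16)
print(block.conv.weight.dtype)  # torch.bfloat16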

timm/layers/separable_conv.py

Lines changed: 2 additions & 2 deletions
@@ -33,7 +33,7 @@ def __init__(
             dtype=None,
     ):
         dd = {'device': device, 'dtype': dtype}
-        super(SeparableConvNormAct, self).__init__()
+        super().__init__()

         self.conv_dw = create_conv2d(
             in_channels,
@@ -57,7 +57,7 @@ def __init__(

         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
-        self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
+        self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs, **dd)

     @property
     def in_channels(self):

timm/models/densenet.py

Lines changed: 39 additions & 23 deletions
@@ -4,7 +4,7 @@
 """
 import re
 from collections import OrderedDict
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Type, Union

 import torch
 import torch.nn as nn
@@ -31,9 +31,11 @@ def __init__(
             num_input_features: int,
             growth_rate: int,
             bn_size: int,
-            norm_layer: type = BatchNormAct2d,
+            norm_layer: Type[nn.Module] = BatchNormAct2d,
             drop_rate: float = 0.,
             grad_checkpointing: bool = False,
+            device=None,
+            dtype=None,
     ) -> None:
         """Initialize DenseLayer.

@@ -45,13 +47,14 @@ def __init__(
             drop_rate: Dropout rate.
             grad_checkpointing: Use gradient checkpointing.
         """
+        dd = {'device': device, 'dtype': dtype}
         super(DenseLayer, self).__init__()
-        self.add_module('norm1', norm_layer(num_input_features)),
+        self.add_module('norm1', norm_layer(num_input_features, **dd)),
         self.add_module('conv1', nn.Conv2d(
-            num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)),
-        self.add_module('norm2', norm_layer(bn_size * growth_rate)),
+            num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False, **dd)),
+        self.add_module('norm2', norm_layer(bn_size * growth_rate, **dd)),
         self.add_module('conv2', nn.Conv2d(
-            bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)),
+            bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False, **dd)),
         self.drop_rate = float(drop_rate)
         self.grad_checkpointing = grad_checkpointing

@@ -129,9 +132,11 @@ def __init__(
             num_input_features: int,
             bn_size: int,
             growth_rate: int,
-            norm_layer: type = BatchNormAct2d,
+            norm_layer: Type[nn.Module] = BatchNormAct2d,
             drop_rate: float = 0.,
             grad_checkpointing: bool = False,
+            device=None,
+            dtype=None,
     ) -> None:
         """Initialize DenseBlock.

@@ -144,6 +149,7 @@ def __init__(
             drop_rate: Dropout rate.
             grad_checkpointing: Use gradient checkpointing.
         """
+        dd = {'device': device, 'dtype': dtype}
         super(DenseBlock, self).__init__()
         for i in range(num_layers):
             layer = DenseLayer(
@@ -153,6 +159,7 @@ def __init__(
                 norm_layer=norm_layer,
                 drop_rate=drop_rate,
                 grad_checkpointing=grad_checkpointing,
+                **dd,
             )
             self.add_module('denselayer%d' % (i + 1), layer)

@@ -182,8 +189,10 @@ def __init__(
             self,
             num_input_features: int,
             num_output_features: int,
-            norm_layer: type = BatchNormAct2d,
-            aa_layer: Optional[type] = None,
+            norm_layer: Type[nn.Module] = BatchNormAct2d,
+            aa_layer: Optional[Type[nn.Module]] = None,
+            device=None,
+            dtype=None,
     ) -> None:
         """Initialize DenseTransition.

@@ -193,12 +202,13 @@ def __init__(
             norm_layer: Normalization layer class.
             aa_layer: Anti-aliasing layer class.
         """
+        dd = {'device': device, 'dtype': dtype}
         super(DenseTransition, self).__init__()
-        self.add_module('norm', norm_layer(num_input_features))
+        self.add_module('norm', norm_layer(num_input_features, **dd))
         self.add_module('conv', nn.Conv2d(
-            num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
+            num_input_features, num_output_features, kernel_size=1, stride=1, bias=False, **dd))
         if aa_layer is not None:
-            self.add_module('pool', aa_layer(num_output_features, stride=2))
+            self.add_module('pool', aa_layer(num_output_features, stride=2, **dd))
         else:
             self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))

@@ -231,11 +241,13 @@ def __init__(
             stem_type: str = '',
             act_layer: str = 'relu',
             norm_layer: str = 'batchnorm2d',
-            aa_layer: Optional[type] = None,
+            aa_layer: Optional[Type[nn.Module]] = None,
             drop_rate: float = 0.,
             proj_drop_rate: float = 0.,
             memory_efficient: bool = False,
             aa_stem_only: bool = True,
+            device=None,
+            dtype=None,
     ) -> None:
         """Initialize DenseNet.

@@ -255,6 +267,7 @@ def __init__(
             memory_efficient: If True, uses checkpointing for memory efficiency.
             aa_stem_only: Apply anti-aliasing only to stem.
         """
+        dd = {'device': device, 'dtype': dtype}
         self.num_classes = num_classes
         super(DenseNet, self).__init__()
         norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
@@ -267,25 +280,25 @@ def __init__(
         else:
             stem_pool = nn.Sequential(*[
                 nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
-                aa_layer(channels=num_init_features, stride=2)])
+                aa_layer(channels=num_init_features, stride=2, **dd)])
         if deep_stem:
             stem_chs_1 = stem_chs_2 = growth_rate
             if 'tiered' in stem_type:
                 stem_chs_1 = 3 * (growth_rate // 4)
                 stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (growth_rate // 4)
             self.features = nn.Sequential(OrderedDict([
-                ('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False)),
-                ('norm0', norm_layer(stem_chs_1)),
-                ('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False)),
-                ('norm1', norm_layer(stem_chs_2)),
-                ('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False)),
-                ('norm2', norm_layer(num_init_features)),
+                ('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False, **dd)),
+                ('norm0', norm_layer(stem_chs_1, **dd)),
+                ('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False, **dd)),
+                ('norm1', norm_layer(stem_chs_2, **dd)),
+                ('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False, **dd)),
+                ('norm2', norm_layer(num_init_features, **dd)),
                 ('pool0', stem_pool),
             ]))
         else:
             self.features = nn.Sequential(OrderedDict([
-                ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
-                ('norm0', norm_layer(num_init_features)),
+                ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False, **dd)),
+                ('norm0', norm_layer(num_init_features, **dd)),
                 ('pool0', stem_pool),
             ]))
         self.feature_info = [
@@ -303,6 +316,7 @@ def __init__(
                 norm_layer=norm_layer,
                 drop_rate=proj_drop_rate,
                 grad_checkpointing=memory_efficient,
+                **dd,
             )
             module_name = f'denseblock{(i + 1)}'
             self.features.add_module(module_name, block)
@@ -317,12 +331,13 @@ def __init__(
                     num_output_features=num_features // 2,
                     norm_layer=norm_layer,
                     aa_layer=transition_aa_layer,
+                    **dd,
                 )
                 self.features.add_module(f'transition{i + 1}', trans)
                 num_features = num_features // 2

         # Final batch norm
-        self.features.add_module('norm5', norm_layer(num_features))
+        self.features.add_module('norm5', norm_layer(num_features, **dd))

         self.feature_info += [dict(num_chs=num_features, reduction=current_stride, module='features.norm5')]
         self.num_features = self.head_hidden_size = num_features
@@ -332,6 +347,7 @@ def __init__(
             self.num_features,
             self.num_classes,
             pool_type=global_pool,
+            **dd,
         )
         self.global_pool = global_pool
         self.head_drop = nn.Dropout(drop_rate)

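With the new arguments in place, the DenseNet classes above can be built directly on a chosen device or in a chosen dtype. A rough usage sketch against this branch (the meta-device variant assumes the weight-init calls in __init__ tolerate meta tensors, which recent PyTorch releases do; growth_rate/block_config shown are the densenet121 defaults):

import torch
from timm.models.densenet import DenseNet

# Build with real weights, created directly in bfloat16.
model = DenseNet(growth_rate=32, block_config=(6, 12, 24, 16), dtype=torch.bfloat16)

# Or build on the meta device to inspect the structure without allocating
# storage, then materialize empty weights before loading a checkpoint.
meta_model = DenseNet(growth_rate=32, block_config=(6, 12, 24, 16), device='meta')
meta_model = meta_model.to_empty(device='cpu')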