@@ -4,7 +4,7 @@
 """
 import re
 from collections import OrderedDict
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Type, Union

 import torch
 import torch.nn as nn
@@ -31,9 +31,11 @@ def __init__(
             num_input_features: int,
             growth_rate: int,
             bn_size: int,
-            norm_layer: type = BatchNormAct2d,
+            norm_layer: Type[nn.Module] = BatchNormAct2d,
             drop_rate: float = 0.,
             grad_checkpointing: bool = False,
+            device=None,
+            dtype=None,
     ) -> None:
         """Initialize DenseLayer.

@@ -45,13 +47,14 @@ def __init__(
             drop_rate: Dropout rate.
             grad_checkpointing: Use gradient checkpointing.
         """
+        dd = {'device': device, 'dtype': dtype}
         super(DenseLayer, self).__init__()
-        self.add_module('norm1', norm_layer(num_input_features)),
+        self.add_module('norm1', norm_layer(num_input_features, **dd)),
         self.add_module('conv1', nn.Conv2d(
-            num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)),
-        self.add_module('norm2', norm_layer(bn_size * growth_rate)),
+            num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False, **dd)),
+        self.add_module('norm2', norm_layer(bn_size * growth_rate, **dd)),
         self.add_module('conv2', nn.Conv2d(
-            bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)),
+            bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False, **dd)),
         self.drop_rate = float(drop_rate)
         self.grad_checkpointing = grad_checkpointing

@@ -129,9 +132,11 @@ def __init__(
             num_input_features: int,
             bn_size: int,
             growth_rate: int,
-            norm_layer: type = BatchNormAct2d,
+            norm_layer: Type[nn.Module] = BatchNormAct2d,
             drop_rate: float = 0.,
             grad_checkpointing: bool = False,
+            device=None,
+            dtype=None,
     ) -> None:
         """Initialize DenseBlock.

@@ -144,6 +149,7 @@ def __init__(
             drop_rate: Dropout rate.
             grad_checkpointing: Use gradient checkpointing.
         """
+        dd = {'device': device, 'dtype': dtype}
         super(DenseBlock, self).__init__()
         for i in range(num_layers):
             layer = DenseLayer(
@@ -153,6 +159,7 @@ def __init__(
                 norm_layer=norm_layer,
                 drop_rate=drop_rate,
                 grad_checkpointing=grad_checkpointing,
+                **dd,
             )
             self.add_module('denselayer%d' % (i + 1), layer)

@@ -182,8 +189,10 @@ def __init__(
             self,
             num_input_features: int,
             num_output_features: int,
-            norm_layer: type = BatchNormAct2d,
-            aa_layer: Optional[type] = None,
+            norm_layer: Type[nn.Module] = BatchNormAct2d,
+            aa_layer: Optional[Type[nn.Module]] = None,
+            device=None,
+            dtype=None,
     ) -> None:
         """Initialize DenseTransition.

@@ -193,12 +202,13 @@ def __init__(
             norm_layer: Normalization layer class.
             aa_layer: Anti-aliasing layer class.
         """
+        dd = {'device': device, 'dtype': dtype}
         super(DenseTransition, self).__init__()
-        self.add_module('norm', norm_layer(num_input_features))
+        self.add_module('norm', norm_layer(num_input_features, **dd))
         self.add_module('conv', nn.Conv2d(
-            num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
+            num_input_features, num_output_features, kernel_size=1, stride=1, bias=False, **dd))
         if aa_layer is not None:
-            self.add_module('pool', aa_layer(num_output_features, stride=2))
+            self.add_module('pool', aa_layer(num_output_features, stride=2, **dd))
         else:
             self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))

@@ -231,11 +241,13 @@ def __init__(
             stem_type: str = '',
             act_layer: str = 'relu',
             norm_layer: str = 'batchnorm2d',
-            aa_layer: Optional[type] = None,
+            aa_layer: Optional[Type[nn.Module]] = None,
             drop_rate: float = 0.,
             proj_drop_rate: float = 0.,
             memory_efficient: bool = False,
             aa_stem_only: bool = True,
+            device=None,
+            dtype=None,
     ) -> None:
         """Initialize DenseNet.

@@ -255,6 +267,7 @@ def __init__(
             memory_efficient: If True, uses checkpointing for memory efficiency.
             aa_stem_only: Apply anti-aliasing only to stem.
         """
+        dd = {'device': device, 'dtype': dtype}
         self.num_classes = num_classes
         super(DenseNet, self).__init__()
         norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
@@ -267,25 +280,25 @@ def __init__(
         else:
             stem_pool = nn.Sequential(*[
                 nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
-                aa_layer(channels=num_init_features, stride=2)])
+                aa_layer(channels=num_init_features, stride=2, **dd)])
         if deep_stem:
             stem_chs_1 = stem_chs_2 = growth_rate
             if 'tiered' in stem_type:
                 stem_chs_1 = 3 * (growth_rate // 4)
                 stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (growth_rate // 4)
             self.features = nn.Sequential(OrderedDict([
-                ('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False)),
-                ('norm0', norm_layer(stem_chs_1)),
-                ('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False)),
-                ('norm1', norm_layer(stem_chs_2)),
-                ('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False)),
-                ('norm2', norm_layer(num_init_features)),
+                ('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False, **dd)),
+                ('norm0', norm_layer(stem_chs_1, **dd)),
+                ('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False, **dd)),
+                ('norm1', norm_layer(stem_chs_2, **dd)),
+                ('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False, **dd)),
+                ('norm2', norm_layer(num_init_features, **dd)),
                 ('pool0', stem_pool),
             ]))
         else:
             self.features = nn.Sequential(OrderedDict([
-                ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
-                ('norm0', norm_layer(num_init_features)),
+                ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False, **dd)),
+                ('norm0', norm_layer(num_init_features, **dd)),
                 ('pool0', stem_pool),
             ]))
         self.feature_info = [
@@ -303,6 +316,7 @@ def __init__(
                 norm_layer=norm_layer,
                 drop_rate=proj_drop_rate,
                 grad_checkpointing=memory_efficient,
+                **dd,
             )
             module_name = f'denseblock{(i + 1)}'
             self.features.add_module(module_name, block)
@@ -317,12 +331,13 @@ def __init__(
                     num_output_features=num_features // 2,
                     norm_layer=norm_layer,
                     aa_layer=transition_aa_layer,
+                    **dd,
                 )
                 self.features.add_module(f'transition{i + 1}', trans)
                 num_features = num_features // 2

         # Final batch norm
-        self.features.add_module('norm5', norm_layer(num_features))
+        self.features.add_module('norm5', norm_layer(num_features, **dd))

         self.feature_info += [dict(num_chs=num_features, reduction=current_stride, module='features.norm5')]
         self.num_features = self.head_hidden_size = num_features
@@ -332,6 +347,7 @@ def __init__(
             self.num_features,
             self.num_classes,
             pool_type=global_pool,
+            **dd,
         )
         self.global_pool = global_pool
         self.head_drop = nn.Dropout(drop_rate)
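A minimal usage sketch of the new `device`/`dtype` factory kwargs added in this diff. It assumes the rest of the PR is applied (so `create_classifier` and the norm-act layers also accept these kwargs), that the class is importable as `timm.models.densenet.DenseNet` as in upstream timm, and that the remaining constructor arguments keep their upstream defaults.

import torch
from timm.models.densenet import DenseNet  # import path assumed from upstream timm

# With the diff applied, parameters and buffers are created directly with the
# requested device/dtype instead of a float32 CPU init followed by .to(device, dtype).
model = DenseNet(device='cpu', dtype=torch.bfloat16)
print(next(model.parameters()).dtype)  # torch.bfloat16
# device='meta' should work the same way for deferred initialization, assuming
# the weight-init routines tolerate meta tensors.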