@@ -169,7 +169,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
             for lr, group in zip(self._swa_lrs, optimizer.param_groups):
                 group["initial_lr"] = lr
 
-            self._swa_scheduler: _LRScheduler = cast(
+            self._swa_scheduler = cast(
                 _LRScheduler,
                 SWALR(
                     optimizer,
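
The first hunk drops the inline `_LRScheduler` annotation and keeps only the `cast`, presumably so the attribute is not re-annotated at the assignment site. Below is a minimal standalone sketch of the same pattern outside the callback; the toy model, learning rate, and SWA hyperparameters are illustrative, not Lightning's values:

    from typing import cast

    import torch
    from torch.optim.lr_scheduler import _LRScheduler
    from torch.optim.swa_utils import SWALR

    # Toy setup; model and hyperparameters are placeholders for illustration.
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Build the SWA scheduler and treat it as a plain _LRScheduler for typing.
    swa_scheduler = cast(
        _LRScheduler,
        SWALR(optimizer, swa_lr=0.05, anneal_epochs=10, anneal_strategy="cos"),
    )

    for _ in range(3):
        optimizer.step()
        swa_scheduler.step()

The `cast` has no runtime effect; it only instructs the type checker to treat the `SWALR` instance as the scheduler base class.
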
@@ -244,19 +244,22 @@ def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> No
         for module in pl_module.modules():
             if not isinstance(module, nn.modules.batchnorm._BatchNorm):
                 continue
+            assert module.running_mean is not None
             module.running_mean = torch.zeros_like(
-                module.running_mean,  # type: ignore[arg-type]
+                module.running_mean,
                 device=pl_module.device,
-                dtype=module.running_mean.dtype,  # type: ignore[union-attr]
+                dtype=module.running_mean.dtype,
             )
+            assert module.running_var is not None
             module.running_var = torch.ones_like(
-                module.running_var,  # type: ignore[arg-type]
+                module.running_var,
                 device=pl_module.device,
-                dtype=module.running_var.dtype,  # type: ignore[union-attr]
+                dtype=module.running_var.dtype,
             )
             self.momenta[module] = module.momentum
-            module.momentum = None  # type: ignore[assignment]
-            module.num_batches_tracked *= 0  # type: ignore[assignment, operator]
+            module.momentum = float()
+            assert module.num_batches_tracked is not None
+            module.num_batches_tracked *= 0
 
     def reset_momenta(self) -> None:
         """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""