Skip to content

Commit a37fc72

Browse files
authored
fixes typing in stochastic_weight_avg.py (follow-up of #13685) (#13860)
1 parent c391170 commit a37fc72

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

src/pytorch_lightning/callbacks/stochastic_weight_avg.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
169169
for lr, group in zip(self._swa_lrs, optimizer.param_groups):
170170
group["initial_lr"] = lr
171171

172-
self._swa_scheduler: _LRScheduler = cast(
172+
self._swa_scheduler = cast(
173173
_LRScheduler,
174174
SWALR(
175175
optimizer,
@@ -244,19 +244,22 @@ def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> No
244244
for module in pl_module.modules():
245245
if not isinstance(module, nn.modules.batchnorm._BatchNorm):
246246
continue
247+
assert module.running_mean is not None
247248
module.running_mean = torch.zeros_like(
248-
module.running_mean, # type: ignore[arg-type]
249+
module.running_mean,
249250
device=pl_module.device,
250-
dtype=module.running_mean.dtype, # type: ignore[union-attr]
251+
dtype=module.running_mean.dtype,
251252
)
253+
assert module.running_var is not None
252254
module.running_var = torch.ones_like(
253-
module.running_var, # type: ignore[arg-type]
255+
module.running_var,
254256
device=pl_module.device,
255-
dtype=module.running_var.dtype, # type: ignore[union-attr]
257+
dtype=module.running_var.dtype,
256258
)
257259
self.momenta[module] = module.momentum
258-
module.momentum = None # type: ignore[assignment]
259-
module.num_batches_tracked *= 0 # type: ignore[assignment, operator]
260+
module.momentum = float()
261+
assert module.num_batches_tracked is not None
262+
module.num_batches_tracked *= 0
260263

261264
def reset_momenta(self) -> None:
262265
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""

0 commit comments

Comments
 (0)