fixes typing in pytorch_lightning/callbacks/stochastic_weight_avg.py #13685
pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -16,19 +16,19 @@
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""
from copy import deepcopy
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, cast, List, Optional, Union

import torch
-from torch import FloatTensor, nn, Tensor
+from torch import nn, Tensor
from torch.optim.swa_utils import SWALR

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
-from pytorch_lightning.utilities.types import LRSchedulerConfig
+from pytorch_lightning.utilities.types import _LRScheduler, LRSchedulerConfig

-_AVG_FN = Callable[[Tensor, Tensor, torch.LongTensor], FloatTensor]
+_AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor]


class StochasticWeightAveraging(Callback):
@@ -106,7 +106,7 @@ def __init__(
        if wrong_type or wrong_float or wrong_list:
            raise MisconfigurationException("The `swa_lrs` should a positive float, or a list of positive floats")

-        if avg_fn is not None and not isinstance(avg_fn, Callable):
+        if avg_fn is not None and not callable(avg_fn):
            raise MisconfigurationException("The `avg_fn` should be callable.")

        if device is not None and not isinstance(device, (torch.device, str)):
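As a small aside on the `callable` change (illustration only, not from the PR; the class name below is made up): the builtin performs the same runtime check as `isinstance(obj, Callable)` without going through `typing.Callable`.

    # Illustration only: builtins.callable covers plain functions, lambdas and objects defining __call__.
    class WithCall:
        def __call__(self) -> None:
            ...

    print(callable(len), callable(lambda: None), callable(WithCall()), callable(42))
    # -> True True True False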
@@ -118,27 +118,29 @@ def __init__(
        self._annealing_strategy = annealing_strategy
        self._avg_fn = avg_fn or self.avg_fn
        self._device = device
-        self._model_contains_batch_norm = None
-        self._average_model = None
+        self._max_epochs: int
+        self._model_contains_batch_norm: bool
+        self._average_model: "pl.LightningModule"

    @property
    def swa_start(self) -> int:
+        assert isinstance(self._swa_epoch_start, int)
        return max(self._swa_epoch_start - 1, 0)  # 0-based

    @property
    def swa_end(self) -> int:
        return self._max_epochs - 1  # 0-based

    @staticmethod
-    def pl_module_contains_batch_norm(pl_module: "pl.LightningModule"):
+    def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool:
        return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules())

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
        # copy the model before moving it to accelerator device.
        with pl_module._prevent_trainer_and_dataloaders_deepcopy():
            self._average_model = deepcopy(pl_module)

-    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
+    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if len(trainer.optimizers) != 1:
            raise MisconfigurationException("SWA currently works with 1 `optimizer`.")
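The replacement of the `= None` defaults with bare annotations relies on a standard Python behaviour; a minimal, self-contained illustration (the class and attribute names here are hypothetical):

    # Illustration only: a bare annotation in __init__ declares the attribute's type for
    # mypy but does not create the attribute at runtime, so it must be assigned (e.g. in
    # setup) before first use.
    class Example:
        def __init__(self) -> None:
            self._value: int  # type declared, nothing assigned yet

        def setup(self) -> None:
            self._value = 3

    e = Example()
    e.setup()
    print(e._value)  # 3; reading it before setup() would raise AttributeError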
@@ -155,7 +157,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
            # virtually increase max_epochs to perform batch norm update on latest epoch.
            trainer.fit_loop.max_epochs += 1

-    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
+    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if trainer.current_epoch == self.swa_start:
            # move average model to request device.
            self._average_model = self._average_model.to(self._device or pl_module.device)
@@ -167,12 +169,15 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
            for lr, group in zip(self._swa_lrs, optimizer.param_groups):
                group["initial_lr"] = lr

-            self._swa_scheduler = SWALR(
-                optimizer,
-                swa_lr=self._swa_lrs,
-                anneal_epochs=self._annealing_epochs,
-                anneal_strategy=self._annealing_strategy,
-                last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
+            self._swa_scheduler: _LRScheduler = cast(
+                _LRScheduler,
+                SWALR(
+                    optimizer,
+                    swa_lr=self._swa_lrs,  # type: ignore[arg-type]
+                    anneal_epochs=self._annealing_epochs,
+                    anneal_strategy=self._annealing_strategy,
+                    last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
+                ),
            )
            # We assert that there is only one optimizer on fit start, so know opt_idx is always 0
            default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0)
@@ -213,10 +218,10 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):

            trainer.accumulate_grad_batches = trainer.num_training_batches

-    def on_train_epoch_end(self, trainer: "pl.Trainer", *args):
+    def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None:
        trainer.fit_loop._skip_backward = False

-    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
+    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # the trainer increases the current epoch before this hook is called
        if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
            # BatchNorm epoch update. Reset state
@@ -229,35 +234,39 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
            self.transfer_weights(self._average_model, pl_module)

    @staticmethod
-    def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule"):
+    def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule") -> None:
        for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()):
            dst_param.detach().copy_(src_param.to(dst_param.device))

-    def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule"):
+    def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> None:
        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154."""
        self.momenta = {}
        for module in pl_module.modules():
            if not isinstance(module, nn.modules.batchnorm._BatchNorm):
                continue
            module.running_mean = torch.zeros_like(
-                module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype
+                module.running_mean,  # type: ignore[arg-type]
+                device=pl_module.device,
+                dtype=module.running_mean.dtype,  # type: ignore[union-attr]
            )
            module.running_var = torch.ones_like(
-                module.running_var, device=pl_module.device, dtype=module.running_var.dtype
+                module.running_var,  # type: ignore[arg-type]
+                device=pl_module.device,
+                dtype=module.running_var.dtype,  # type: ignore[union-attr]
            )

Comment on lines +248 to +255: IMO these shouldn't have been ignored, but instead added assertions that they're not None.
            self.momenta[module] = module.momentum
-            module.momentum = None
-            module.num_batches_tracked *= 0
+            module.momentum = None  # type: ignore[assignment]
+            module.num_batches_tracked *= 0  # type: ignore[assignment, operator]
Comment: Same comment here about asserting that it's not None.
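A possible shape for the suggestion in these two comments (a sketch only, not the code merged in this PR; the helper name `_reset_batch_norm` is made up, and whether mypy drops every ignore with these assertions is not verified here):

    # Sketch of the assert-based alternative suggested in review.
    import torch
    from torch import nn

    def _reset_batch_norm(module: nn.modules.batchnorm._BatchNorm, device: torch.device) -> None:
        # running_mean, running_var and num_batches_tracked are Optional[Tensor]; asserting
        # they are not None should let mypy narrow the types instead of silencing the errors.
        assert module.running_mean is not None
        assert module.running_var is not None
        assert module.num_batches_tracked is not None
        module.running_mean = torch.zeros_like(module.running_mean, device=device, dtype=module.running_mean.dtype)
        module.running_var = torch.ones_like(module.running_var, device=device, dtype=module.running_var.dtype)
        module.num_batches_tracked *= 0
        # momentum is annotated as a float, so assigning None likely still needs an ignore.
        module.momentum = None  # type: ignore[assignment]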
-    def reset_momenta(self):
+    def reset_momenta(self) -> None:
        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
        for bn_module in self.momenta:
            bn_module.momentum = self.momenta[bn_module]

    @staticmethod
    def update_parameters(
-        average_model: "pl.LightningModule", model: "pl.LightningModule", n_averaged: torch.LongTensor, avg_fn: _AVG_FN
-    ):
+        average_model: "pl.LightningModule", model: "pl.LightningModule", n_averaged: Tensor, avg_fn: _AVG_FN
+    ) -> None:
        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112."""
        for p_swa, p_model in zip(average_model.parameters(), model.parameters()):
            device = p_swa.device
@@ -268,8 +277,6 @@ def update_parameters(
        n_averaged += 1

    @staticmethod
-    def avg_fn(
-        averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: torch.LongTensor
-    ) -> FloatTensor:
+    def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor:
        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97."""
        return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)
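As a quick aside on the relaxed `_AVG_FN`/`avg_fn` signatures: the update is an incremental mean. A small self-contained check (names and values here are illustrative only, not part of the PR):

    # Standalone check: feeding avg_fn a sequence of tensors one by one reproduces their plain average.
    import torch
    from torch import Tensor

    def avg_fn(averaged: Tensor, current: Tensor, num_averaged: Tensor) -> Tensor:
        return averaged + (current - averaged) / (num_averaged + 1)

    avg = torch.zeros(3)
    n = torch.tensor(0)
    for value in (1.0, 3.0, 5.0):
        avg = avg_fn(avg, torch.full((3,), value), n)
        n += 1
    print(avg)  # tensor([3., 3., 3.]) == mean of 1, 3 and 5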
Comment: The type annotation is redundant when you cast.
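For what this comment points at, a sketch of the suggested simplification (an excerpt in the context of `on_train_epoch_start`, not the committed code): `cast` already types the expression as `_LRScheduler`, so the annotation on `self._swa_scheduler` can be dropped.

    # Sketch only: same call as in the diff, without the redundant annotation.
    self._swa_scheduler = cast(
        _LRScheduler,
        SWALR(
            optimizer,
            swa_lr=self._swa_lrs,  # type: ignore[arg-type]
            anneal_epochs=self._annealing_epochs,
            anneal_strategy=self._annealing_strategy,
            last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
        ),
    )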