fixes typing in pytorch_lightning/callbacks/stochastic_weight_avg.py #13685

Merged: 8 commits, Jul 26, 2022
1 change: 0 additions & 1 deletion pyproject.toml
@@ -49,7 +49,6 @@ warn_no_return = "False"
module = [
"pytorch_lightning.callbacks.progress.rich_progress",
"pytorch_lightning.callbacks.quantization",
"pytorch_lightning.callbacks.stochastic_weight_avg",
"pytorch_lightning.core.datamodule",
"pytorch_lightning.core.decorators",
"pytorch_lightning.core.mixins.device_dtype_mixin",
67 changes: 37 additions & 30 deletions src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -16,19 +16,19 @@
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""
from copy import deepcopy
from typing import Callable, List, Optional, Union
from typing import Any, Callable, cast, List, Optional, Union

import torch
from torch import FloatTensor, nn, Tensor
from torch import nn, Tensor
from torch.optim.swa_utils import SWALR

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerConfig
from pytorch_lightning.utilities.types import _LRScheduler, LRSchedulerConfig

_AVG_FN = Callable[[Tensor, Tensor, torch.LongTensor], FloatTensor]
_AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor]


class StochasticWeightAveraging(Callback):
@@ -106,7 +106,7 @@ def __init__(
if wrong_type or wrong_float or wrong_list:
raise MisconfigurationException("The `swa_lrs` should a positive float, or a list of positive floats")

if avg_fn is not None and not isinstance(avg_fn, Callable):
if avg_fn is not None and not callable(avg_fn):
raise MisconfigurationException("The `avg_fn` should be callable.")

if device is not None and not isinstance(device, (torch.device, str)):
@@ -118,27 +118,29 @@ def __init__(
self._annealing_strategy = annealing_strategy
self._avg_fn = avg_fn or self.avg_fn
self._device = device
self._model_contains_batch_norm = None
self._average_model = None
self._max_epochs: int
self._model_contains_batch_norm: bool
self._average_model: "pl.LightningModule"

@property
def swa_start(self) -> int:
assert isinstance(self._swa_epoch_start, int)
return max(self._swa_epoch_start - 1, 0) # 0-based

@property
def swa_end(self) -> int:
return self._max_epochs - 1 # 0-based

@staticmethod
def pl_module_contains_batch_norm(pl_module: "pl.LightningModule"):
def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool:
return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules())

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
# copy the model before moving it to accelerator device.
with pl_module._prevent_trainer_and_dataloaders_deepcopy():
self._average_model = deepcopy(pl_module)

def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if len(trainer.optimizers) != 1:
raise MisconfigurationException("SWA currently works with 1 `optimizer`.")

@@ -155,7 +157,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
# virtually increase max_epochs to perform batch norm update on latest epoch.
trainer.fit_loop.max_epochs += 1

def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if trainer.current_epoch == self.swa_start:
# move average model to request device.
self._average_model = self._average_model.to(self._device or pl_module.device)
@@ -167,12 +169,15 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
for lr, group in zip(self._swa_lrs, optimizer.param_groups):
group["initial_lr"] = lr

self._swa_scheduler = SWALR(
optimizer,
swa_lr=self._swa_lrs,
anneal_epochs=self._annealing_epochs,
anneal_strategy=self._annealing_strategy,
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
self._swa_scheduler: _LRScheduler = cast(
Contributor review comment on the line above: "The type annotation is redundant when you cast."
Suggested change:
self._swa_scheduler: _LRScheduler = cast(
self._swa_scheduler = cast(
(A short sketch of this pattern follows at the end of this hunk.)
_LRScheduler,
SWALR(
optimizer,
swa_lr=self._swa_lrs, # type: ignore[arg-type]
anneal_epochs=self._annealing_epochs,
anneal_strategy=self._annealing_strategy,
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
),
)
# We assert that there is only one optimizer on fit start, so know opt_idx is always 0
default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0)
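To make the reviewer's point concrete, here is a minimal sketch (not the PR's code) showing that `cast` alone already fixes the static type of the assignment, so the extra variable annotation repeats the same information. The `_LRScheduler` import mirrors the one in this diff; the model, optimizer, and SWA hyperparameters are placeholder values.

```python
# Minimal illustrative sketch, not the PR's code.
from typing import cast

from torch import nn
from torch.optim import SGD
from torch.optim.swa_utils import SWALR

from pytorch_lightning.utilities.types import _LRScheduler

optimizer = SGD(nn.Linear(4, 2).parameters(), lr=0.1)

# `cast(_LRScheduler, ...)` already tells the type checker that the expression
# evaluates to an `_LRScheduler`, so the annotated form
#     swa_scheduler: _LRScheduler = cast(_LRScheduler, SWALR(...))
# adds nothing; the reviewer's suggestion keeps only the cast:
swa_scheduler = cast(_LRScheduler, SWALR(optimizer, swa_lr=0.05, anneal_epochs=5))
```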
@@ -213,10 +218,10 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo

trainer.accumulate_grad_batches = trainer.num_training_batches

def on_train_epoch_end(self, trainer: "pl.Trainer", *args):
def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None:
trainer.fit_loop._skip_backward = False

def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
# the trainer increases the current epoch before this hook is called
if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
# BatchNorm epoch update. Reset state
@@ -229,35 +234,39 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
self.transfer_weights(self._average_model, pl_module)

@staticmethod
def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule"):
def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule") -> None:
for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()):
dst_param.detach().copy_(src_param.to(dst_param.device))

def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule"):
def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> None:
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154."""
self.momenta = {}
for module in pl_module.modules():
if not isinstance(module, nn.modules.batchnorm._BatchNorm):
continue
module.running_mean = torch.zeros_like(
module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype
module.running_mean, # type: ignore[arg-type]
device=pl_module.device,
dtype=module.running_mean.dtype, # type: ignore[union-attr]
)
module.running_var = torch.ones_like(
module.running_var, device=pl_module.device, dtype=module.running_var.dtype
module.running_var, # type: ignore[arg-type]
device=pl_module.device,
dtype=module.running_var.dtype, # type: ignore[union-attr]
Contributor review comment on lines +248 to +255: "IMO these shouldn't have been ignored, but instead added assertions that running_{mean,var} are not None." (A sketch of the assert-based alternative follows the next comment below.)
)
self.momenta[module] = module.momentum
module.momentum = None
module.num_batches_tracked *= 0
module.momentum = None # type: ignore[assignment]
module.num_batches_tracked *= 0 # type: ignore[assignment, operator]
Contributor review comment: "Same comment here about asserting that it's not None." (See the sketch below.)
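As a concrete illustration of the assert-based alternative suggested in the two review comments above, here is a minimal sketch (not the PR's code) on a plain `nn.BatchNorm1d`; the layer size is a placeholder. The running-stat buffers can be `None` (for example when `track_running_stats=False`), which is why the type checker treats them as optional; asserting that they exist narrows their types so no `# type: ignore` comments are needed.

```python
# Minimal illustrative sketch, not the PR's code.
import torch
from torch import nn

module = nn.BatchNorm1d(8)

# Assert that the optional buffers exist before touching them; this narrows
# their static type from Optional[Tensor] to Tensor for the type checker.
assert module.running_mean is not None
assert module.running_var is not None
module.running_mean = torch.zeros_like(module.running_mean)
module.running_var = torch.ones_like(module.running_var)

# Same idea for the batch counter: assert it exists, then reset it in place.
assert module.num_batches_tracked is not None
module.num_batches_tracked.zero_()
```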


def reset_momenta(self):
def reset_momenta(self) -> None:
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
for bn_module in self.momenta:
bn_module.momentum = self.momenta[bn_module]

@staticmethod
def update_parameters(
average_model: "pl.LightningModule", model: "pl.LightningModule", n_averaged: torch.LongTensor, avg_fn: _AVG_FN
):
average_model: "pl.LightningModule", model: "pl.LightningModule", n_averaged: Tensor, avg_fn: _AVG_FN
) -> None:
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112."""
for p_swa, p_model in zip(average_model.parameters(), model.parameters()):
device = p_swa.device
@@ -268,8 +277,6 @@ def update_parameters(
n_averaged += 1

@staticmethod
def avg_fn(
averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: torch.LongTensor
) -> FloatTensor:
def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor:
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97."""
return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)