-
Notifications
You must be signed in to change notification settings - Fork 3.5k
fixes typing in pytorch_lightning/callbacks/stochastic_weight_avg.py #13685
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fixes typing in pytorch_lightning/callbacks/stochastic_weight_avg.py #13685
Conversation
for more information, see https://pre-commit.ci
…m:donlapark/lightning into fix_typing_callback_stochastic_weight_avg merge with prior changes
Codecov Report
@@ Coverage Diff @@
## master #13685 +/- ##
=========================================
- Coverage 86% 74% -12%
=========================================
Files 327 330 +3
Lines 25505 27641 +2136
=========================================
- Hits 21903 20475 -1428
- Misses 3602 7166 +3564 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can open a follow-up PR with the suggestions applied
module.running_mean, # type: ignore[arg-type] | ||
device=pl_module.device, | ||
dtype=module.running_mean.dtype, # type: ignore[union-attr] | ||
) | ||
module.running_var = torch.ones_like( | ||
module.running_var, device=pl_module.device, dtype=module.running_var.dtype | ||
module.running_var, # type: ignore[arg-type] | ||
device=pl_module.device, | ||
dtype=module.running_var.dtype, # type: ignore[union-attr] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO these shouldn't have been ignored, but instead added assertions that running_{mean,var}
are not None
module.momentum = None | ||
module.num_batches_tracked *= 0 | ||
module.momentum = None # type: ignore[assignment] | ||
module.num_batches_tracked *= 0 # type: ignore[assignment, operator] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comment here about asserting that it's not None
anneal_epochs=self._annealing_epochs, | ||
anneal_strategy=self._annealing_strategy, | ||
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1, | ||
self._swa_scheduler: _LRScheduler = cast( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type annotation is redundant when you cast
self._swa_scheduler: _LRScheduler = cast( | |
self._swa_scheduler = cast( |
What does this PR do?
Fixes typing in pytorch in
pytorch_lightning/callbacks/stochastic_weight_avg.py
from #13445.Note:
isinstance(..., Callable)
throwing an arg-type error (isinstance(..., collections.abc.Callable) complains about "_SpecialForm" python/mypy#6864).self._swa_epoch_start
is afloat
, it is expected to be converted toint
by a call toon_fit_start()
. However,mypy
does not see this and throws in a return-value error.torch.optim.swa_utils.SWALR
has not been type-annotated yet, so the type of the argument ofswa_lr
is automatically set tofloat
._BatchNorm.running_mean
,_BatchNorm.running_var
and_BatchNorm.num_batches_tracked
have typesOptional[Tensor]
. These attributes are initialized to be zero tensors viaself.reset_parameters()
butmypy
does not see this.Does your PR introduce any breaking changes? If yes, please list them.
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Did you have fun?
Make sure you had fun coding 🙃