
fixes typing in pytorch_lightning/callbacks/stochastic_weight_avg.py #13685


Merged: 8 commits merged into Lightning-AI:master from donlapark:fix_typing_callback_stochastic_weight_avg on Jul 26, 2022

Conversation

@donlapark (Contributor) commented on Jul 16, 2022

What does this PR do?

Fixes typing in pytorch_lightning/callbacks/stochastic_weight_avg.py, from #13445.

Note:

  • Line 109: there is a known mypy issue in which isinstance(..., Callable) raises an arg-type error (isinstance(..., collections.abc.Callable) complains about _SpecialForm, python/mypy#6864).
  • Line 127: if self._swa_epoch_start is a float, it is expected to have been converted to int by a call to on_fit_start(). However, mypy cannot see this and raises a return-value error.
  • Line 175: torch.optim.swa_utils.SWALR has not been type-annotated yet, so the type of its swa_lr argument is set to float.
  • Lines 246-258: _BatchNorm.running_mean, _BatchNorm.running_var, and _BatchNorm.num_batches_tracked have type Optional[Tensor]. These attributes are initialized to zero tensors via self.reset_parameters(), but mypy cannot see this (this and the isinstance note above are illustrated in the sketch after this list).
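
A minimal, self-contained sketch of the first and last notes above (check_fn and reset_stats are hypothetical helpers for illustration, not the PR's actual code):

from typing import Callable, Optional

import torch
from torch.nn.modules.batchnorm import _BatchNorm


def check_fn(avg_fn: Optional[Callable]) -> bool:
    # Valid at runtime, but mypy flags isinstance(..., Callable) with an
    # arg-type error (python/mypy#6864), hence the ignore:
    return isinstance(avg_fn, Callable)  # type: ignore[arg-type]


def reset_stats(module: _BatchNorm, device: torch.device) -> None:
    # running_mean is declared Optional[Tensor]; reset_parameters()
    # guarantees a tensor at runtime, but mypy cannot prove it, so the
    # arg-type and union-attr errors are silenced:
    module.running_mean = torch.zeros_like(
        module.running_mean,  # type: ignore[arg-type]
        device=device,
        dtype=module.running_mean.dtype,  # type: ignore[union-attr]
    )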

Does your PR introduce any breaking changes? If yes, please list them.

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

@otaj mentioned this pull request (52 tasks) on Jul 18, 2022
donlap added 2 commits on July 18, 2022 16:13:
  • …m:donlapark/lightning into fix_typing_callback_stochastic_weight_avg
  • merge with prior changes
@github-actions bot added the pl (Generic label for PyTorch Lightning package) label on Jul 20, 2022
@codecov codecov bot commented on Jul 20, 2022

Codecov Report

Merging #13685 (676b768) into master (238c991) will decrease coverage by 12%.
The diff coverage is 100%.

@@            Coverage Diff            @@
##           master   #13685     +/-   ##
=========================================
- Coverage      86%      74%    -12%     
=========================================
  Files         327      330      +3     
  Lines       25505    27641   +2136     
=========================================
- Hits        21903    20475   -1428     
- Misses       3602     7166   +3564     

@otaj otaj enabled auto-merge (squash) July 25, 2022 07:54
@mergify mergify bot added the ready (PRs ready to be merged) label on Jul 25, 2022
@carmocca carmocca modified the milestones: pl:1.7, pl:1.8 Jul 25, 2022
@donlapark donlapark closed this Jul 25, 2022
auto-merge was automatically disabled July 25, 2022 09:47 (pull request was closed)

@donlapark donlapark deleted the fix_typing_callback_stochastic_weight_avg branch July 25, 2022 09:47
@donlapark donlapark restored the fix_typing_callback_stochastic_weight_avg branch July 25, 2022 09:49
@donlapark donlapark reopened this Jul 25, 2022
@otaj otaj enabled auto-merge (squash) July 25, 2022 12:35
@otaj otaj merged commit e77accf into Lightning-AI:master Jul 26, 2022
@carmocca (Contributor) left a comment:

You can open a follow-up PR with the suggestions applied.

Comment on lines +248 to +255:

    module.running_mean,  # type: ignore[arg-type]
    device=pl_module.device,
    dtype=module.running_mean.dtype,  # type: ignore[union-attr]
)
module.running_var = torch.ones_like(
-   module.running_var, device=pl_module.device, dtype=module.running_var.dtype
+   module.running_var,  # type: ignore[arg-type]
+   device=pl_module.device,
+   dtype=module.running_var.dtype,  # type: ignore[union-attr]
Contributor: IMO these shouldn't have been ignored; instead, assertions should have been added that running_{mean,var} are not None, e.g.:
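
A brief sketch of that assert-based alternative, assuming the module/pl_module context of the surrounding code (this is the reviewer's suggestion, not what the PR merged):

assert module.running_mean is not None
assert module.running_var is not None
# After the asserts, mypy narrows both attributes from Optional[Tensor]
# to Tensor, so no type: ignore comments are needed:
module.running_mean = torch.zeros_like(
    module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype
)
module.running_var = torch.ones_like(
    module.running_var, device=pl_module.device, dtype=module.running_var.dtype
)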

- module.momentum = None
- module.num_batches_tracked *= 0
+ module.momentum = None  # type: ignore[assignment]
+ module.num_batches_tracked *= 0  # type: ignore[assignment, operator]
Contributor: Same comment here about asserting that it's not None, e.g.:
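
Again a short sketch of the suggested alternative, under the same assumptions (not what was merged):

assert module.num_batches_tracked is not None
# The assert narrows num_batches_tracked from Optional[Tensor] to Tensor,
# so the operator/assignment ignores on the in-place multiply go away;
# momentum is declared as float, so its ignore would remain:
module.momentum = None  # type: ignore[assignment]
module.num_batches_tracked *= 0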

    anneal_epochs=self._annealing_epochs,
    anneal_strategy=self._annealing_strategy,
    last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
self._swa_scheduler: _LRScheduler = cast(
Contributor: The type annotation is redundant when you cast.

Suggested change:
- self._swa_scheduler: _LRScheduler = cast(
+ self._swa_scheduler = cast(
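
For context, a small standalone illustration of why the annotation is redundant (hypothetical model/optimizer, not the PR's code): typing.cast already declares the static type of the expression, so the target's type is inferred without an annotation.

from typing import cast

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.swa_utils import SWALR

model = torch.nn.Linear(2, 2)  # hypothetical stand-in model
optimizer = SGD(model.parameters(), lr=0.1)

# cast() already tells mypy that the result is an _LRScheduler, so an
# explicit annotation on the assignment target adds nothing:
scheduler = cast(_LRScheduler, SWALR(optimizer, swa_lr=0.05))
# mypy now treats `scheduler` as _LRScheduler with no annotation needed.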

Labels: callback: swa · code quality · community (This PR is from the community) · pl (Generic label for PyTorch Lightning package) · ready (PRs ready to be merged)
Projects: None yet
Development: Successfully merging this pull request may close these issues.
7 participants