-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Added progress flag to model getters #875
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov Report
@@ Coverage Diff @@
## master #875 +/- ##
=========================================
+ Coverage 55.17% 55.9% +0.73%
=========================================
Files 36 37 +1
Lines 3375 3338 -37
Branches 553 531 -22
=========================================
+ Hits 1862 1866 +4
+ Misses 1375 1357 -18
+ Partials 138 115 -23
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot!
* Minor refactoring of ShuffleNetV2 Added progress flag following #875. Further the following refactoring was also done: 1) added `version` argument in shufflenetv2 method and removed the operations for converting the `width_mult` arg to float and string. 2) removed `num_classes` argument and **kwargs from functions except `ShuffleNetV2` * removed `version` arg * Update shufflenetv2.py * Removed the try except block * Update shufflenetv2.py * Changed version from float to str * Replace `width_mult` with `stages_out_channels` Removes the need of `_getStages` function.
This PR uses a protected method for loading and initializing the segmentation models. Relevant #875
This adds support for disabling the display of the download progress as requested in #862. It depends on
torch.hub.load_state_dict_from_url
, which is not included in the latest stable release. Thus, this PR will not pass the CI until #826 is merged as @fmassa mentioned.I also did some refactoring along the way:
model
s with multiple different architectures (DenseNet
,ResNet
,SqueezeNet
, andVGG
) I added a protected function that allmodel
getter functions use.version
identifier inSqueezeNet
fromfloat
tostr
(e.g.1.0
to'1_0'
).Edit: Apparently this passed CI checks. Beats me why though.