-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Refactoring of ShuffleNetV2 #889
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Added progress flag following #875. Further the following refactoring was also done: 1) added `version` argument in shufflenetv2 method and removed the operations for converting the `width_mult` arg to float and string. 2) removed `num_classes` argument and **kwargs from functions except `ShuffleNetV2`
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me.
I think that in general MobileNet and ShuffleNet they all follow the same base architectures (as explained in https://arxiv.org/abs/1812.03443 ), so I wonder if at some point they shouldn't just both use the same base implementation.
Also ccing @newstzpz who has worked on FBNet
Codecov Report
@@ Coverage Diff @@
## master #889 +/- ##
==========================================
- Coverage 56.7% 56.69% -0.01%
==========================================
Files 38 38
Lines 3432 3422 -10
Branches 540 539 -1
==========================================
- Hits 1946 1940 -6
+ Misses 1370 1365 -5
- Partials 116 117 +1
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ekagra-ranjan thanks for your proposal.
A few queries:
- Why remove the num_classes argument? Having the model decoupled from a specific dataset sounds quite useful to me. Is it mandatory, in other words, is there a standard function type you're trying to confirm to?
The reuse of the term 'version' in the context of Shufflenet which has already two versions (where this is ShuffleNet V2) should be avoided for clarity-sake, imo.(Edited: doesn't appear in most recent commit)- Removing the explicit conversion (float()) will make it very confusing should someone try to create the model without going through the utility functions, e.g.
ShufflenetV2()
calls will fail (butShufflenetV2(width_mult=1.0)
won't). The explicit conversion is crucial.
@fmassa is torchvision imagenet 1k specific? |
@barrh the models in The change that @ekagra-ranjan did do not remove support for passing a different number of classes. Instead, it follows the pattern from other files, where if vision/torchvision/models/resnet.py Line 238 in ce18507
You are right, but one solution to this would be to change the default argument from In general, I'm not very keen on having the version be a float value. This can lead to a number of problems down the road, because the comparison between floating points, see #728 (comment) for example. I think it might just be simpler to use a string instead to represent the |
@fmassa thanks for elaborated reply. I think having num_classes argument in the standard utility function is the way to go for the long run (even for resnet50), however, I see why it makes sense to confirm with resnet50 type of call. With this patch |
@barrh I think that a better solution is to let the constructor of the model The reason for that is that it's more generic on the user perspective, and avoids all the problems we have been discussing so far. This is also btw what's done for Thoughts? |
sounds great. |
@barrh ok, looks like we are getting into a consensus. |
@ekagra-ranjan given that you have already addressed most of the comments, to make things simpler would you mind replacing the |
Sure! |
@barrh Thank you for your views. I addressed your concern about
I agree to this. The documentation does not have any mention about When someone has to use a pretrained model and finetune it they use the following syntax:
Notice that nowhere was One solution to address this problem would be to perform the above operation of replacing the fc layer within the code of the model. In the case of ResNets, the protected function
This way we can make the docs more informative by adding What is your view on this @fmassa ? |
Removes the need of `_getStages` function.
@ekagra-ranjan please check out the PR I pushed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just waiting for CI to finish
Added progress flag following #875. Further the following refactoring were also done:
width_mult
arg to float.num_classes
argument from functions exceptShuffleNetV2
try expect
block for handling errors for absent pretrained weights andwidth_mult
mult_width
withstages_out_channels
.