Refactoring of ShuffleNetV2 #889

Merged: 7 commits merged into pytorch:master on May 7, 2019

Conversation

ekagra-ranjan
Contributor

@ekagra-ranjan ekagra-ranjan commented May 6, 2019

Added progress flag following #875. Further, the following refactoring was also done:

  1. removed the operations for converting the width_mult arg to float.
  2. removed num_classes argument from functions except ShuffleNetV2
  3. removed try/except block for handling errors for absent pretrained weights and width_mult
  4. replaced width_mult with stages_out_channels (a rough sketch of the resulting call pattern is shown right after this list).
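
For illustration, a minimal, hedged sketch of the call pattern this describes, assuming the shufflenetv2_x1_0-style helper names used later in this thread (the exact public names and which widths ship pretrained weights are assumptions, not something this description confirms):

```python
import torchvision.models as models

# Hypothetical helper names taken from this thread; the released torchvision
# API may spell them differently.
# `progress` (following #875) toggles the download progress bar for the
# pretrained weights.
model = models.shufflenetv2_x1_0(pretrained=True, progress=True)

# num_classes is no longer an explicit parameter of the helpers; when
# training from scratch it can still be forwarded through **kwargs.
scratch = models.shufflenetv2_x0_5(pretrained=False, num_classes=10)
```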

Member

@fmassa fmassa left a comment


Looks good to me.

I think that in general MobileNet and ShuffleNet both follow the same base architecture (as explained in https://arxiv.org/abs/1812.03443), so I wonder if at some point they shouldn't just use the same base implementation.

Also ccing @newstzpz who has worked on FBNet

@codecov-io

codecov-io commented May 6, 2019

Codecov Report

Merging #889 into master will decrease coverage by <.01%.
The diff coverage is 77.77%.


@@            Coverage Diff             @@
##           master     #889      +/-   ##
==========================================
- Coverage    56.7%   56.69%   -0.01%     
==========================================
  Files          38       38              
  Lines        3432     3422      -10     
  Branches      540      539       -1     
==========================================
- Hits         1946     1940       -6     
+ Misses       1370     1365       -5     
- Partials      116      117       +1
| Impacted Files | Coverage Δ |
| --- | --- |
| torchvision/models/shufflenetv2.py | 90.36% <77.77%> (+6.49%) ⬆️ |
| torchvision/transforms/transforms.py | 81.89% <0%> (-0.65%) ⬇️ |

Legend:
Δ = absolute <relative> (impact), ø = not affected, ? = missing data

Contributor

@barrh barrh left a comment


@ekagra-ranjan thanks for your proposal.

A few queries:

  1. Why remove the num_classes argument? Having the model decoupled from a specific dataset sounds quite useful to me. Is it mandatory; in other words, is there a standard function signature you're trying to conform to?
  2. The reuse of the term 'version' in the context of ShuffleNet, which already has two versions (where this is ShuffleNet V2), should be avoided for clarity's sake, imo. (Edited: doesn't appear in most recent commit)
  3. Removing the explicit conversion (float()) will make it very confusing should someone try to create the model without going through the utility functions, e.g. ShufflenetV2() calls will fail (but ShufflenetV2(width_mult=1.0) won't). The explicit conversion is crucial.

@barrh
Contributor

barrh commented May 6, 2019

@fmassa is torchvision imagenet 1k specific?
Wouldn't it be better to have dataset-agnostic models?

@fmassa
Member

fmassa commented May 7, 2019

@barrh the models in torchvision are not imagenet 1k specific, and I'm currently adding support for other tasks which are not trained on imagenet.

The change that @ekagra-ranjan made does not remove support for passing a different number of classes. Instead, it follows the pattern from other files, where **kwargs can accept num_classes:

def resnet50(pretrained=False, progress=True, **kwargs):
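
For context, a minimal usage illustration of that pattern (a sketch, assuming num_classes is simply forwarded through **kwargs to the model constructor):

```python
from torchvision import models

# num_classes rides along in **kwargs down to the ResNet constructor, so a
# non-ImageNet head can be requested without a dedicated named parameter.
# The pretrained weights are ImageNet-shaped, so this is normally combined
# with pretrained=False.
model = models.resnet50(pretrained=False, num_classes=10)
```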

> Removing the explicit conversion (float()) will make it very confusing should someone try to create the model without going through the utility functions, e.g. ShufflenetV2() calls will fail (but ShufflenetV2(width_mult=1.0) won't). The explicit conversion is crucial.

You are right, but one solution to this would be to change the default argument from 1 to 1.0 in the class.

In general, I'm not very keen on having the version be a float value. This can lead to a number of problems down the road because of comparisons between floating points; see #728 (comment) for example. I think it might just be simpler to use a string instead to represent the width_mult, as there is really no reason why it's a float: it is only used for mapping to the number of layers.
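
A quick, generic illustration of the floating-point comparison pitfall being referred to (plain Python, not torchvision code):

```python
# Exact equality on floats is fragile, which is why using a float-valued
# width_mult as a lookup key or in == comparisons is risky.
print(0.1 + 0.2 == 0.3)        # False

widths = {0.5: 'x0.5', 1.0: 'x1.0'}
key = 0.1 + 0.2 + 0.2          # 0.5000000000000001
print(key in widths)           # False, even though it "should" be 0.5
```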

@barrh
Contributor

barrh commented May 7, 2019

@fmassa thanks for the detailed reply.

I think having the num_classes argument in the standard utility function is the way to go for the long run (even for resnet50); however, I see why it makes sense to conform to the resnet50 type of call.

With this patch, ShufflenetV2(width_mult=1) fails. With width_mult changed to a string type, ShufflenetV2(width_mult='1') will fail.
Although I agree explicit conversions are generally bad, I don't see an immediate benefit from removing this one. Perhaps we need to expose subclasses ShufflenetV2{0.5,1,1.5,2} instead of ShufflenetV2.

@fmassa
Member

fmassa commented May 7, 2019

@barrh I think that a better solution is to let the constructor of the model ShuffleNetV2 take stage_out_channels as an argument, and construct each one of the stages inside the corresponding shufflenetv2_x0_5 etc. functions.

The reason for that is that it's more generic from the user's perspective, and it avoids all the problems we have been discussing so far. This is also, btw, what's done for resnet, and it makes it easier for the user to try out different configurations.

Thoughts?
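
To make the proposal concrete, here is a rough, self-contained sketch of that pattern. The class and helper names follow this thread, and the stage configurations are the ones reported in the ShuffleNetV2 paper; the real torchvision class also builds the conv stages and handles pretrained-weight loading, which is omitted here:

```python
import torch.nn as nn

class ShuffleNetV2(nn.Module):
    # Sketch only: the constructor receives the architecture description
    # instead of a width_mult flag, so any configuration can be built.
    def __init__(self, stages_repeats, stages_out_channels, num_classes=1000):
        super().__init__()
        self.stages_repeats = stages_repeats
        self.stages_out_channels = stages_out_channels
        self.fc = nn.Linear(stages_out_channels[-1], num_classes)

def shufflenetv2_x0_5(pretrained=False, progress=True, **kwargs):
    # Each variant helper just picks its own configuration, mirroring how
    # resnet18/resnet50 pick their block counts (weight loading omitted).
    return ShuffleNetV2([4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)

def shufflenetv2_x1_0(pretrained=False, progress=True, **kwargs):
    return ShuffleNetV2([4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
```

With this shape a user can also instantiate an arbitrary width directly, e.g. ShuffleNetV2([4, 8, 4], [24, 64, 128, 256, 512]), without touching the helpers.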

@barrh
Contributor

barrh commented May 7, 2019

Sounds great.
There's a lot of room to make this model more generic; specifically, stage_repeats is more interesting than stage_out_channels imo: https://github.com/pytorch/vision/blob/master/torchvision/models/shufflenetv2.py#L108

@fmassa
Member

fmassa commented May 7, 2019

@barrh ok, looks like we are reaching a consensus.
I think there are a few things that can be improved in shufflenet, as you point out. Would you mind sending a follow-up PR addressing the comments I've mentioned here? Then I can merge this PR now to move faster.

@fmassa
Member

fmassa commented May 7, 2019

@ekagra-ranjan given that you have already addressed most of the comments, to make things simpler would you mind replacing the width_mult string argument in ShuffleNetV2 with a layers argument or something like that, which is a list of ints, and replacing getStages altogether?

@ekagra-ranjan
Contributor Author

Sure!

@ekagra-ranjan
Contributor Author

ekagra-ranjan commented May 7, 2019

@barrh Thank you for your views. I addressed your concern about width_mult arg by converting width_mult to str as suggested by @fmassa.

> I think having the num_classes argument in the standard utility function is the way to go for the long run (even for resnet50); however, I see why it makes sense to conform to the resnet50 type of call.

I agree with this. The documentation does not mention num_classes as an argument to the torchvision models. @fmassa Do you think that we should mention num_classes as an optional argument in the docs? This would address the problem highlighted above by @barrh, although it has the following issue:

When someone has to use a pretrained model and fine-tune it, they use the following syntax:

import torch.nn as nn
from torchvision import models

model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)  # num_classes chosen by the user

Notice that nowhere was num_classes passed while constructing the resnet18 object; num_classes was specified only after the model object had been created. So the only scenario in which the user would pass num_classes to the model constructor is when they are not using the pretrained weights, i.e. when pretrained=False. So just mentioning the optional argument num_classes in the docs could be misleading without also mentioning when not to change the default value of num_classes.

One solution to address this problem would be to perform the above operation of replacing the fc layer within the code of the model. In the case of ResNets, the protected function _resnet can be modified as:

def _resnet(arch, inplanes, planes, pretrained, num_classes, progress, **kwargs):
    
    if pretrained:
        model = ResNet(inplanes, planes, **kwargs)
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)

        if num_classes != 1000:
            num_ftrs = model.fc.in_features
            model.fc = nn.Linear(num_ftrs, num_classes)

    else:
        model = ResNet(inplanes, planes, num_classes, **kwargs)

    return model

This way we can make the docs more informative by adding num_classes as an optional argument without any conditions on its use. It also simplifies fine-tuning any torchvision model: the user just passes a single parameter, num_classes.
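
To illustrate, under that proposal fine-tuning would collapse to a single call (a sketch of the proposed behaviour only; with torchvision as it stands at the time of this PR, this call fails with a shape mismatch when loading the pretrained fc weights):

```python
from torchvision import models

# Proposed behaviour: load the ImageNet-pretrained backbone, then replace the
# final fc layer with a fresh num_classes-way classifier automatically.
model = models.resnet18(pretrained=True, num_classes=10)
```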

What is your view on this @fmassa ?

@barrh barrh mentioned this pull request May 7, 2019
Removes the need of  `_getStages` function.
@barrh
Contributor

barrh commented May 7, 2019

@ekagra-ranjan please check out the PR I pushed.

@ekagra-ranjan changed the title from "Minor refactoring of ShuffleNetV2" to "Refactoring of ShuffleNetV2" on May 7, 2019
Member

@fmassa fmassa left a comment


LGTM, just waiting for CI to finish

@fmassa fmassa merged commit 0564df4 into pytorch:master May 7, 2019