
[proposal] Use self.flatten instead of torch.flatten and when becomes possible derive ResNet from nn.Sequential (scripting+quantization is blocker), would simplify model surgery in the most frequent cases #3331


Open
vadimkantorov opened this issue Jan 31, 2021 · 23 comments


@vadimkantorov

vadimkantorov commented Jan 31, 2021

Currently, in https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L243:

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

If it instead used x = self.flatten(x), it would simplify model surgery: del model.avgpool, model.flatten, model.fc. In that case the class could also just derive from nn.Sequential and use an OrderedDict to pass the submodules (like in https://discuss.pytorch.org/t/ux-mix-of-nn-sequential-and-nn-moduledict/104724/2?u=vadimkantorov), which would preserve checkpoint compatibility as well. The forward method could then be removed.
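
For illustration, here is a rough sketch (not the actual torchvision code) of what that could look like; layer1..layer4 and the usual _make_layer plumbing are omitted for brevity, and ResNetAsSequential is just a placeholder name:

    from collections import OrderedDict

    import torch
    import torch.nn as nn

    class ResNetAsSequential(nn.Sequential):
        # Submodules are registered via an OrderedDict, so state_dict keys
        # (conv1, bn1, ..., fc) match the current ResNet and checkpoints stay compatible.
        def __init__(self, num_classes=1000):
            super().__init__(OrderedDict([
                ('conv1', nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)),
                ('bn1', nn.BatchNorm2d(64)),
                ('relu', nn.ReLU(inplace=True)),
                ('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
                # ('layer1', ...), ..., ('layer4', ...) built via _make_layer as today
                ('avgpool', nn.AdaptiveAvgPool2d((1, 1))),
                ('flatten', nn.Flatten(1)),
                ('fc', nn.Linear(64, num_classes)),  # 64 only because layer1..layer4 are omitted here
            ]))
            # No forward() needed: nn.Sequential runs the submodules in order.

    model = ResNetAsSequential()
    del model.avgpool, model.flatten, model.fc  # the model surgery described above
    print(model(torch.randn(1, 3, 224, 224)).shape)  # raw feature map instead of logits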

@datumbox
Contributor

datumbox commented Feb 1, 2021

Thanks for the proposal. Replacing torch.flatten with nn.Flatten seems like it won't break anything, but could you describe your use case a bit more and how it will help you achieve what you want?

Concerning doing a major refactoring and inheriting from Sequential, this might have side effects on all the other models that depend on resnet (segmentation, object detection etc.), so I'm not sure if it can be done in a backward-compatible manner. Note also that replacing the forward is not something we can do, due to how quantization works:

def forward(self, x):
    x = self.quant(x)
    # Ensure scriptability
    # super(QuantizableResNet,self).forward(x)
    # is not scriptable
    x = self._forward_impl(x)
    x = self.dequant(x)
    return x

@fmassa Let me know if you see any problems about replacing flatten.

@fmassa
Member

fmassa commented Feb 1, 2021

I would be fine replacing torch.flatten with nn.Flatten, although it would only simplify model surgery if we were to make further modifications to the ResNet model as you suggested (making it inherit from nn.Sequential). Without this, del model.fc etc. alone wouldn't be enough.

So from this perspective, there is limited value in replacing torch.flatten with nn.Flatten, as one is only syntactic sugar for the other.

@vadimkantorov
Author

vadimkantorov commented Feb 1, 2021

Oh, I didn't realize:

# Ensure scriptability 
# super(QuantizableResNet,self).forward(x) 
# is not scriptable 

If this becomes fixed, I guess inheriting from nn.Sequential should be fine.

My real use case: I'm working with a codebase, https://github.com/ajabri/videowalk/blob/master/code/resnet.py#L41, that had to reimplement the ResNet forward because it wants to do simple model surgery.

@fmassa
Member

fmassa commented Feb 2, 2021

@vadimkantorov model surgery in PyTorch generally requires re-writing forward or other parts of the model, except for nn.Sequential, so I would say it's a valid requirement to ask users to be a bit more verbose.

@vadimkantorov
Author

model surgery in PyTorch generally requires re-writing forward or other parts of the model, except for nn.Sequential

Of course. I understand that in general that's the case, but if it can be made simpler, I think it should be, even if in the general case full re-writing is required. In this particular case, it seems that because of quantization scripting limitations deriving from nn.Sequential is a no-go, but once that's solved it would be nice to derive ResNet from nn.Sequential.

I'll rename the issue accordingly.

@vadimkantorov vadimkantorov changed the title [proposal] Use nn.Flatten in ResNet [proposal] Derive ResNet from nn.Sequential when it becomes possible (scripting+quantization is blocker), would simplify model surgery in the simplest cases Feb 2, 2021
@vadimkantorov
Author

vadimkantorov commented Feb 10, 2021

It seems that deriving from nn.Sequential would also be good from a DeepSpeed-compatibility standpoint: pytorch/pytorch#51574

For the quantization case, couldn't it just insert quant/dequant modules at the beginning and the end of the Sequential? Then everything should work and no modification of forward would be needed.

@vadimkantorov
Author

vadimkantorov commented Feb 10, 2021

They are already modules:

        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

One would just need to modify the constructor and insert them before conv1 and after layer4.
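
To make the suggestion concrete, a rough sketch (not the actual torchvision QuantizableResNet; the backbone here is a toy stand-in for conv1 ... layer4 ... fc):

    from collections import OrderedDict

    import torch
    import torch.nn as nn

    backbone = nn.Sequential(OrderedDict([
        ('conv1', nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)),
        ('relu', nn.ReLU(inplace=True)),
        ('avgpool', nn.AdaptiveAvgPool2d((1, 1))),
        ('flatten', nn.Flatten(1)),
        ('fc', nn.Linear(64, 10)),
    ]))

    # Insert the stubs as the first and last entries of the Sequential,
    # so no custom forward() is required.
    quantizable = nn.Sequential(OrderedDict(
        [('quant', torch.quantization.QuantStub())]
        + list(backbone.named_children())
        + [('dequant', torch.quantization.DeQuantStub())]
    ))
    print(quantizable(torch.randn(1, 3, 224, 224)).shape)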

@vadimkantorov vadimkantorov changed the title [proposal] Derive ResNet from nn.Sequential when it becomes possible (scripting+quantization is blocker), would simplify model surgery in the simplest cases [proposal] Use self.flatten instead of torch.flatten and when becomes possible derive ResNet from nn.Sequential (scripting+quantization is blocker), would simplify model surgery in the simplest cases Jun 6, 2021
@vadimkantorov vadimkantorov changed the title [proposal] Use self.flatten instead of torch.flatten and when becomes possible derive ResNet from nn.Sequential (scripting+quantization is blocker), would simplify model surgery in the simplest cases [proposal] Use self.flatten instead of torch.flatten and when becomes possible derive ResNet from nn.Sequential (scripting+quantization is blocker), would simplify model surgery in the most frequent cases Jun 6, 2021
@vadimkantorov
Author

vadimkantorov commented Jun 6, 2021

@fmassa Even if we just replace torch.flatten with self.flatten, it already simplifies model surgery: model.avgpool = model.flatten = model.fc = nn.Identity() (instead of del model.avgpool, model.flatten, model.fc).
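
For concreteness, a sketch of the difference (the flatten submodule is hypothetical; with today's hard-coded torch.flatten call, only avgpool and fc can be overridden this way):

    import torch
    import torch.nn as nn
    import torchvision

    model = torchvision.models.resnet18()

    # Works today: avgpool and fc are submodules, so they can be neutralized in place.
    model.avgpool = model.fc = nn.Identity()

    # With the proposed flatten submodule, the whole head could be disabled the same way:
    # model.avgpool = model.flatten = model.fc = nn.Identity()

    out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # flattened layer4 features, since torch.flatten(x, 1) still runs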

@fmassa
Member

fmassa commented Jun 7, 2021

@vadimkantorov I think we should revisit how one does model surgery in PyTorch. In #3597 I have a prototype feature (which I'll be updating soon) that provides a much more robust and generic way of doing feature extraction.

As it relies on FX, and quantization also relies on FX for extracting the graph, I think this could be a good compromise which would address all your points, I believe.

@vadimkantorov
Author

I think deleting some upper layers is only tangentially related to intermediate feature extraction: model.avgpool = model.flatten = model.fc = nn.Identity() is a very simple and understandable way of removing some layers (and it doesn't involve a recent new technology).

I propose to still do self.flatten, even if you merge your new FX-based feature.

@fmassa
Member

fmassa commented Jun 7, 2021

@vadimkantorov the magic with the FX-based approach is that if you specify any part of the model, you can actually delete the end of the model (removing both the compute and the parameters). So it is actually a generalization of what you are proposing.

@vadimkantorov
Author

vadimkantorov commented Jun 7, 2021

I understand that it also solves this case :) But self.flatten is so simple, doesn't break any back-compat, and doesn't require learning the new compilation functionality just for this simple goal :)

I think both are worthwhile :) (And maybe even figuring out how to derive from nn.Sequential for the DeepSpeed benefits as well, but that's separate.)

And FX would probably still forward through the "layers-to-be-removed", and the large fully-connected layer can be quite costly.

@fmassa
Member

fmassa commented Sep 8, 2021

@vadimkantorov we merged #4302, which should provide a generic way of performing model surgery.

Could you give this a try and provide us feedback if it is enough for your use-cases?

@vadimkantorov
Author

Very simple: remove the average pooling, remove the last fc layer. I feel that having to plug in generic FX / graph rewriting for this is overkill.

@fmassa
Member

fmassa commented Sep 8, 2021

@vadimkantorov the thing you just mentioned can be done in one line:

model_surgery = create_feature_extractor(model, ['layer4'])

or more generically

nodes, _ = get_graph_node_names(model)
model_surgery = create_feature_extractor(model, nodes[-3])

@vadimkantorov
Author

vadimkantorov commented Sep 8, 2021

I don't even need the fc layer to run. Will it still be run? Or will it be DCE'd?

I mean, I understand that FX is also a way to achieve this, but just being able to do del backbone.fc or backbone.fc = nn.Identity() is a much simpler way to achieve this for simple models.

@fmassa
Member

fmassa commented Sep 8, 2021

The FC layer (and its parameters) will be DCE'd from the graph and won't be executed, so this is taken care of for you.
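
A quick sanity check of that DCE behaviour (a sketch; the exact state_dict keys may vary by torchvision version):

    import torchvision
    from torchvision.models.feature_extraction import create_feature_extractor

    m = torchvision.models.resnet18()
    mm = create_feature_extractor(m, ['layer4'])

    print(any(k.startswith('fc.') for k in m.state_dict()))   # True: the original model has fc
    print(any(k.startswith('fc.') for k in mm.state_dict()))  # expected False: fc was pruned away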

@vadimkantorov
Author

vadimkantorov commented Sep 8, 2021

I mean, it's great to have a generic solution, but having an equivalent, much simpler way (when possible) that users already know how to use is also a win.

It will then also probably be less debuggable. E.g. can I put a breakpoint in the transformed code?

@fmassa
Member

fmassa commented Sep 8, 2021

The transformed code can be printed etc, and is executed as standard Python code by the Python interpreter, so you can jump to every line of it in a Python debugger.

For example

import torchvision
import torchvision.models.feature_extraction

m = torchvision.models.resnet18()
mm = torchvision.models.feature_extraction.create_feature_extractor(m, ['layer2'])
print(mm.code)

will give you

def forward(self, x : torch.Tensor):
    conv1 = self.conv1(x);  x = None
    bn1 = self.bn1(conv1);  conv1 = None
    relu = self.relu(bn1);  bn1 = None
    maxpool = self.maxpool(relu);  relu = None
    layer1_0_conv1 = getattr(self.layer1, "0").conv1(maxpool)
    layer1_0_bn1 = getattr(self.layer1, "0").bn1(layer1_0_conv1);  layer1_0_conv1 = None
    layer1_0_relu = getattr(self.layer1, "0").relu(layer1_0_bn1);  layer1_0_bn1 = None
    layer1_0_conv2 = getattr(self.layer1, "0").conv2(layer1_0_relu);  layer1_0_relu = None
    layer1_0_bn2 = getattr(self.layer1, "0").bn2(layer1_0_conv2);  layer1_0_conv2 = None
    add = layer1_0_bn2 + maxpool;  layer1_0_bn2 = maxpool = None
    layer1_0_relu_1 = getattr(self.layer1, "0").relu(add);  add = None
    layer1_1_conv1 = getattr(self.layer1, "1").conv1(layer1_0_relu_1)
    layer1_1_bn1 = getattr(self.layer1, "1").bn1(layer1_1_conv1);  layer1_1_conv1 = None
    layer1_1_relu = getattr(self.layer1, "1").relu(layer1_1_bn1);  layer1_1_bn1 = None
    layer1_1_conv2 = getattr(self.layer1, "1").conv2(layer1_1_relu);  layer1_1_relu = None
    layer1_1_bn2 = getattr(self.layer1, "1").bn2(layer1_1_conv2);  layer1_1_conv2 = None
    add_1 = layer1_1_bn2 + layer1_0_relu_1;  layer1_1_bn2 = layer1_0_relu_1 = None
    layer1_1_relu_1 = getattr(self.layer1, "1").relu(add_1);  add_1 = None
    layer2_0_conv1 = getattr(self.layer2, "0").conv1(layer1_1_relu_1)
    layer2_0_bn1 = getattr(self.layer2, "0").bn1(layer2_0_conv1);  layer2_0_conv1 = None
    layer2_0_relu = getattr(self.layer2, "0").relu(layer2_0_bn1);  layer2_0_bn1 = None
    layer2_0_conv2 = getattr(self.layer2, "0").conv2(layer2_0_relu);  layer2_0_relu = None
    layer2_0_bn2 = getattr(self.layer2, "0").bn2(layer2_0_conv2);  layer2_0_conv2 = None
    layer2_0_downsample_0 = getattr(getattr(self.layer2, "0").downsample, "0")(layer1_1_relu_1);  layer1_1_relu_1 = None
    layer2_0_downsample_1 = getattr(getattr(self.layer2, "0").downsample, "1")(layer2_0_downsample_0);  layer2_0_downsample_0 = None
    add_2 = layer2_0_bn2 + layer2_0_downsample_1;  layer2_0_bn2 = layer2_0_downsample_1 = None
    layer2_0_relu_1 = getattr(self.layer2, "0").relu(add_2);  add_2 = None
    layer2_1_conv1 = getattr(self.layer2, "1").conv1(layer2_0_relu_1)
    layer2_1_bn1 = getattr(self.layer2, "1").bn1(layer2_1_conv1);  layer2_1_conv1 = None
    layer2_1_relu = getattr(self.layer2, "1").relu(layer2_1_bn1);  layer2_1_bn1 = None
    layer2_1_conv2 = getattr(self.layer2, "1").conv2(layer2_1_relu);  layer2_1_relu = None
    layer2_1_bn2 = getattr(self.layer2, "1").bn2(layer2_1_conv2);  layer2_1_conv2 = None
    add_3 = layer2_1_bn2 + layer2_0_relu_1;  layer2_1_bn2 = layer2_0_relu_1 = None
    layer2_1_relu_1 = getattr(self.layer2, "1").relu(add_3);  add_3 = None
    return {'layer2': layer2_1_relu_1}

which is what the Python interpreter will execute.

@vadimkantorov
Author

vadimkantorov commented Sep 8, 2021

Well, it's great to have, but compared to just stepping into resnet.py, this is not the same :) jk

I agree this is a great, powerful feature that allows using a lot of backbones, and it's good that it unblocks this issue, but I don't see a reason not to improve the basic ResNet by moving flatten into an attribute and maybe turning the quantized ResNet into a Sequential.

@fmassa
Member

fmassa commented Sep 8, 2021

I think the key idea here is that we are providing a single (generic?) solution that handles many more use-cases than what was possible before. Overriding modules with nn.Identity() was a hacky solution that didn't work most of the time (only inside nn.Sequential-style models), and can lead to confusion / silent bugs in many cases, especially for newer users, I think.

Having the forward implemented for the models (even if it is a straight Sequential) can still be beneficial for users, as it allows for the exact same things you have been advocating for in your past messages (namely seeing the execution code, more easily putting breakpoints / prints, etc.).

Still, users are free to use what they think best fit their needs.

@vadimkantorov
Author

vadimkantorov commented Sep 8, 2021

I was of course not talking about "newer users". Of course these hacks are for someone who already understands the backbone model source code and wants to avoid the boilerplate of copy-pasting the forward. This is no more "advanced" than the recommended re-initialization of individual model components (e.g. model.roi_heads.box_predictor) found in https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
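
For reference, the kind of component re-initialization that tutorial recommends looks roughly like this (paraphrased from the detection finetuning tutorial; num_classes is up to the user):

    import torchvision
    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    num_classes = 2  # e.g. background + one object class
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)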

In-the-wild surgery by resetting model.fc = nn.Identity(): https://lernapparat.de/resnet-how-many-models/

@vadimkantorov
Author

One more reason for models to be nn.Sequential whenever possible: https://pytorch.org/docs/stable/checkpoint.html?highlight=checkpoint_sequential#torch.utils.checkpoint.checkpoint_sequential
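
A minimal sketch of that benefit (hypothetical toy model; checkpoint_sequential works on an nn.Sequential or a list of modules, re-running each segment during backward instead of storing all intermediate activations):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint_sequential

    model = nn.Sequential(
        nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
        nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
        nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
    )
    x = torch.randn(2, 3, 224, 224, requires_grad=True)
    out = checkpoint_sequential(model, 3, x)  # 3 checkpointed segments
    out.sum().backward()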
