[proposal] Use self.flatten instead of torch.flatten and, when it becomes possible, derive ResNet from nn.Sequential (scripting + quantization is the blocker); this would simplify model surgery in the most frequent cases #3331
Thanks for the proposal. Replacing torch.flatten with nn.Flatten doesn't seem like it would break anything, but could you describe your use-case a bit more and how this change would help you achieve what you want? Concerning a major refactoring that inherits from nn.Sequential: this might have side-effects on all the other models that depend on resnet (segmentation, object detection, etc.), so I'm not sure it can be done in a backward-compatible manner. Note also that replacing the forward is not something we can do, due to how quantization works: vision/torchvision/models/quantization/resnet.py, lines 93 to 100 at 6116812.
@fmassa Let me know if you see any problems with replacing flatten.
I would be fine replacing torch.flatten with nn.Flatten. Note, though, that the quantized ResNet re-implements forward anyway (to insert the quant/dequant calls), so from this perspective there is limited value in replacing it.
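For reference, a minimal sketch of what the nn.Flatten variant of the head would look like (a hypothetical ResNetHead class for illustration, not torchvision's actual code):

```python
import torch
import torch.nn as nn

class ResNetHead(nn.Module):
    """Hypothetical head: flatten is a registered submodule instead of a torch.flatten call."""
    def __init__(self, in_features: int = 512, num_classes: int = 1000):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten(1)  # replaces x = torch.flatten(x, 1) in forward
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(self.flatten(self.avgpool(x)))

head = ResNetHead()
print(head(torch.rand(1, 512, 7, 7)).shape)  # torch.Size([1, 1000])
```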
Oh, I didn't realize that the quantized ResNet overrides forward like that.
If this gets fixed, I guess inheriting from nn.Sequential should be fine. My real use-case: I'm working with a codebase, https://github.com/ajabri/videowalk/blob/master/code/resnet.py#L41, that had to re-implement the ResNet forward because it wants to do simple model surgery.
@vadimkantorov model surgery in PyTorch generally requires re-writing forward or other parts of the model, except in a few simple cases.
Of course. I understand that in general this is the case, but if it can be made simpler, I think it should be, even if the general case still requires a full re-write. In this particular case, it seems that deriving from nn.Sequential is a no-go because of quantization/scripting limitations, but once that is solved it would be nice to derive ResNet from nn.Sequential. I'll rename the issue accordingly.
It seems that deriving from nn.Sequential would also be good from a DeepSpeed-compatibility standpoint: pytorch/pytorch#51574. For the quantization case, couldn't it just insert quant/dequant-calling modules at the beginning and the end of the sequential? Then everything should work and no modification of forward would be needed.
They are already modules: the quantized model creates QuantStub and DeQuantStub in its constructor.
One would just need to modify the constructor and insert them before conv1 and after layer4.
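A rough sketch of that idea, assuming a hypothetical Sequential-style quantizable model (the tiny backbone below is only a stand-in for the real ResNet stages):

```python
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub

# Stand-in backbone; the real conv1/bn1/relu/maxpool/layer1..layer4/avgpool/flatten/fc would go here.
backbone = OrderedDict([
    ('conv1', nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)),
    ('bn1', nn.BatchNorm2d(64)),
    ('relu', nn.ReLU(inplace=True)),
])

# Quant/dequant stubs inserted as ordinary first/last submodules,
# so no hand-written forward() would be needed.
quantizable = nn.Sequential(OrderedDict(
    [('quant', QuantStub())] + list(backbone.items()) + [('dequant', DeQuantStub())]
))

# Before prepare/convert the stubs are pass-through, so the model runs as usual.
print(quantizable(torch.rand(1, 3, 224, 224)).shape)  # torch.Size([1, 64, 112, 112])
```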
@fmassa Even if we just replace torch.flatten with self.flatten, it already simplifies model surgery: the whole head (avgpool / flatten / fc) can then be deleted or swapped out without re-writing forward.
@vadimkantorov I think we should revisit how one does model surgery in PyTorch. In #3597 I have a prototype feature (which I'll be updating soon) that provides a much more robust and generic way of doing feature extraction. As it relies on FX, and quantization is also relying on FX for extracting the graph, I think this could be a good compromise that would address all your points.
I think deleting some upper layers is only tangentially related to intermediate feature extraction, so I propose to still do the replacement of torch.flatten with self.flatten.
@vadimkantorov the magic of the FX-based approach is that if you specify any part of the model, you can actually delete the end of the model (by removing both the compute and the parameters). So it is actually a generalization of what you are proposing.
I understand that it also solves this case :) But I think both are worthwhile :) (And maybe even figuring out how to derive from nn.Sequential, for the DeepSpeed benefits as well, but that's separate.) Also, the FX version would probably still forward through the "layers-to-be-removed", and the large fully-connected layer can be quite costly.
@vadimkantorov we merged #4302, which should provide a generic way of performing model surgery. Could you give this a try and let us know whether it is enough for your use-cases?
Very simple: remove the average pooling and the last fc layer. I feel that having to plug in generic FX rewriting for this is overkill.
@vadimkantorov the thing you just mentioned can be done in one line:

```python
model_surgery = create_feature_extractor(model, ['layer4'])
```

or, more generically:

```python
nodes, _ = get_graph_node_names(model)
model_surgery = create_feature_extractor(model, [nodes[-3]])
```
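For completeness, a small usage sketch of the one-liner above (imports and the dummy input are mine, assuming a plain resnet18):

```python
import torch
from torchvision.models import resnet18
from torchvision.models.feature_extraction import create_feature_extractor

model = resnet18()
model_surgery = create_feature_extractor(model, ['layer4'])

out = model_surgery(torch.rand(1, 3, 224, 224))  # returns a dict keyed by node name
print(out['layer4'].shape)  # torch.Size([1, 512, 7, 7])
```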
I don't even need the fc layer to run. Will the extractor still run it, or will it be removed by dead-code elimination? I mean, I understand that FX is also a way to achieve this, but just being able to delete the unused submodules directly would be simpler.
The FC layer (and its parameters) will be DCE'd from the graph and won't be executed, so this is taken care of for you.
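A quick way to check this, assuming (as stated above) that the unused submodules really are dropped from the resulting GraphModule:

```python
import torch
from torchvision.models import resnet18
from torchvision.models.feature_extraction import create_feature_extractor

m = resnet18()
mm = create_feature_extractor(m, ['layer4'])

# The head should no longer exist on the extractor, and its parameters should be gone.
print(hasattr(mm, 'fc'), hasattr(mm, 'avgpool'))  # expected: False False
print(sum(p.numel() for p in mm.parameters()) <
      sum(p.numel() for p in m.parameters()))     # expected: True
```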
I mean, it's great to have a generic solution, but having an equivalent, much simpler way (when possible) that users already know how to use is also a win. The transformed code will then also probably be less debuggable. E.g., can I put a breakpoint in the transformed code?
The transformed code can be printed etc., and is executed as standard Python code by the Python interpreter, so you can step through every line of it in a Python debugger. For example:

```python
m = torchvision.models.resnet18()
mm = torchvision.models.feature_extraction.create_feature_extractor(m, ['layer2'])
print(mm.code)
```

will give you:

```python
def forward(self, x : torch.Tensor):
    conv1 = self.conv1(x); x = None
    bn1 = self.bn1(conv1); conv1 = None
    relu = self.relu(bn1); bn1 = None
    maxpool = self.maxpool(relu); relu = None
    layer1_0_conv1 = getattr(self.layer1, "0").conv1(maxpool)
    layer1_0_bn1 = getattr(self.layer1, "0").bn1(layer1_0_conv1); layer1_0_conv1 = None
    layer1_0_relu = getattr(self.layer1, "0").relu(layer1_0_bn1); layer1_0_bn1 = None
    layer1_0_conv2 = getattr(self.layer1, "0").conv2(layer1_0_relu); layer1_0_relu = None
    layer1_0_bn2 = getattr(self.layer1, "0").bn2(layer1_0_conv2); layer1_0_conv2 = None
    add = layer1_0_bn2 + maxpool; layer1_0_bn2 = maxpool = None
    layer1_0_relu_1 = getattr(self.layer1, "0").relu(add); add = None
    layer1_1_conv1 = getattr(self.layer1, "1").conv1(layer1_0_relu_1)
    layer1_1_bn1 = getattr(self.layer1, "1").bn1(layer1_1_conv1); layer1_1_conv1 = None
    layer1_1_relu = getattr(self.layer1, "1").relu(layer1_1_bn1); layer1_1_bn1 = None
    layer1_1_conv2 = getattr(self.layer1, "1").conv2(layer1_1_relu); layer1_1_relu = None
    layer1_1_bn2 = getattr(self.layer1, "1").bn2(layer1_1_conv2); layer1_1_conv2 = None
    add_1 = layer1_1_bn2 + layer1_0_relu_1; layer1_1_bn2 = layer1_0_relu_1 = None
    layer1_1_relu_1 = getattr(self.layer1, "1").relu(add_1); add_1 = None
    layer2_0_conv1 = getattr(self.layer2, "0").conv1(layer1_1_relu_1)
    layer2_0_bn1 = getattr(self.layer2, "0").bn1(layer2_0_conv1); layer2_0_conv1 = None
    layer2_0_relu = getattr(self.layer2, "0").relu(layer2_0_bn1); layer2_0_bn1 = None
    layer2_0_conv2 = getattr(self.layer2, "0").conv2(layer2_0_relu); layer2_0_relu = None
    layer2_0_bn2 = getattr(self.layer2, "0").bn2(layer2_0_conv2); layer2_0_conv2 = None
    layer2_0_downsample_0 = getattr(getattr(self.layer2, "0").downsample, "0")(layer1_1_relu_1); layer1_1_relu_1 = None
    layer2_0_downsample_1 = getattr(getattr(self.layer2, "0").downsample, "1")(layer2_0_downsample_0); layer2_0_downsample_0 = None
    add_2 = layer2_0_bn2 + layer2_0_downsample_1; layer2_0_bn2 = layer2_0_downsample_1 = None
    layer2_0_relu_1 = getattr(self.layer2, "0").relu(add_2); add_2 = None
    layer2_1_conv1 = getattr(self.layer2, "1").conv1(layer2_0_relu_1)
    layer2_1_bn1 = getattr(self.layer2, "1").bn1(layer2_1_conv1); layer2_1_conv1 = None
    layer2_1_relu = getattr(self.layer2, "1").relu(layer2_1_bn1); layer2_1_bn1 = None
    layer2_1_conv2 = getattr(self.layer2, "1").conv2(layer2_1_relu); layer2_1_relu = None
    layer2_1_bn2 = getattr(self.layer2, "1").bn2(layer2_1_conv2); layer2_1_conv2 = None
    add_3 = layer2_1_bn2 + layer2_0_relu_1; layer2_1_bn2 = layer2_0_relu_1 = None
    layer2_1_relu_1 = getattr(self.layer2, "1").relu(add_3); add_3 = None
    return {'layer2': layer2_1_relu_1}
```

which is exactly what the Python interpreter executes.
Well, it's great to have, but compared to just stepping into resnet.py it's not quite the same :) Just kidding: I agree this is a great, powerful feature that allows using a lot of backbones, and it's good that it unblocks this issue. But I don't see a reason not to improve the basic ResNet by moving flatten into an attribute, and maybe turning the quantized ResNet into a Sequential.
I think the key idea here is that we are providing a single, generic solution that handles many more use-cases than what was possible before. Overriding modules with nn.Identity is more of a hack and can be confusing, especially for newer users. Having the flatten as a submodule would only help in a narrow set of cases. Still, users are free to use what they think best fits their needs.
I was of course not talking about "newer users". Of course these hacks are for someone who already understands the backbone model's source code and wants to avoid the boilerplate of copy-pasting the forward. This is no more "advanced" than the commonly recommended re-initialization of individual model components (e.g. replacing model.fc with a new nn.Linear for fine-tuning). In-the-wild surgery by resetting head submodules is already common.
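To illustrate the kind of in-the-wild hack being discussed (my own minimal example, not taken from the linked codebase): resetting the head submodules to nn.Identity works today, but torch.flatten stays baked into forward, so you get the flattened layer4 feature map instead of pooled features:

```python
import torch
import torch.nn as nn
from torchvision.models import resnet18

model = resnet18()
model.avgpool = nn.Identity()
model.fc = nn.Identity()

# forward still calls torch.flatten(x, 1) between avgpool and fc,
# so the output is the flattened 512x7x7 feature map rather than logits.
feats = model(torch.rand(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 25088])
```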
One more reason for models to be nn.Sequential whenever possible: https://pytorch.org/docs/stable/checkpoint.html?highlight=checkpoint_sequential#torch.utils.checkpoint.checkpoint_sequential
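A sketch of why that helps; since the real ResNet is not an nn.Sequential, the snippet below rebuilds resnet18 as one by hand (the nn.Flatten submodule is exactly the change proposed in this issue):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential
from torchvision.models import resnet18

m = resnet18()
seq = nn.Sequential(
    m.conv1, m.bn1, m.relu, m.maxpool,
    m.layer1, m.layer2, m.layer3, m.layer4,
    m.avgpool, nn.Flatten(1), m.fc,
)

# Activation checkpointing "for free": recompute activations in 3 segments during backward.
x = torch.rand(2, 3, 224, 224, requires_grad=True)
out = checkpoint_sequential(seq, 3, x)
out.sum().backward()
```

If the model were an nn.Sequential to begin with, the manual re-wrapping step would disappear.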
Currently, in https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L243, the forward calls torch.flatten. If it instead used

```python
x = self.flatten(x)
```

(with self.flatten = nn.Flatten(1) registered in the constructor), it would simplify model surgery:

```python
del model.avgpool, model.flatten, model.fc
```

Also, in this case the class could just derive from nn.Sequential and pass its submodules via an OrderedDict (like in https://discuss.pytorch.org/t/ux-mix-of-nn-sequential-and-nn-moduledict/104724/2?u=vadimkantorov); this would preserve checkpoint compatibility as well. The forward method could then be removed.
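A skeleton of what the Sequential-derived variant could look like (my sketch under the proposal's assumptions: the stages passed in stand in for the real residual layers, and keeping torchvision's submodule names would keep state_dict keys, and hence pretrained checkpoints, compatible):

```python
from collections import OrderedDict

import torch
import torch.nn as nn

class SequentialResNet(nn.Sequential):
    """Sketch only: a ResNet-shaped model derived from nn.Sequential; no forward() needed."""
    def __init__(self, stages, num_classes=1000, out_channels=512):
        super().__init__(OrderedDict([
            ('conv1', nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)),
            ('bn1', nn.BatchNorm2d(64)),
            ('relu', nn.ReLU(inplace=True)),
            ('maxpool', nn.MaxPool2d(3, stride=2, padding=1)),
            ('layer1', stages[0]), ('layer2', stages[1]),
            ('layer3', stages[2]), ('layer4', stages[3]),
            ('avgpool', nn.AdaptiveAvgPool2d((1, 1))),
            ('flatten', nn.Flatten(1)),
            ('fc', nn.Linear(out_channels, num_classes)),
        ]))

# Toy usage with identity stages standing in for the residual blocks:
model = SequentialResNet([nn.Identity() for _ in range(4)], out_channels=64)
print(model(torch.rand(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])

# Model surgery then becomes plain attribute deletion, as proposed above:
del model.avgpool, model.flatten, model.fc
print(model(torch.rand(1, 3, 224, 224)).shape)  # torch.Size([1, 64, 56, 56]) - identity stages don't downsample
```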