
Ability to modify the RetinaNet model after training #6614


Open
Monk5088 opened this issue Sep 20, 2022 · 2 comments

Comments


Monk5088 commented Sep 20, 2022

🚀 The feature

After training a RetinaNet model, we are not able to change the number of classes for the next training session. For example, if the model is trained on 56 classes, the classification subnet of RetinaNet outputs a 56-way one-hot encoded vector; how can we reuse the same model weights for a new dataset that contains only 40 of those classes?
My issue is similar: I trained my RetinaNet on 3 classes, and now I want to use the same model weights for a new dataset that has only the first 2 classes of the previous dataset. When I try model.load_state_dict(torch.load(PATH)), it throws a classifier size mismatch:


/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1496         if len(error_msgs) > 0:
   1497             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1498                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1499         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1500 

RuntimeError: Error(s) in loading state_dict for RetinaNet:
	size mismatch for classifier.3.weight: copying a param with shape torch.Size([3, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([2, 128, 3, 3]).
	size mismatch for classifier.3.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([2]).
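
One common workaround (a minimal sketch, not taken from this thread; model and PATH stand for your model instance and checkpoint path) is to drop the checkpoint entries whose shapes no longer match the new model and load the rest with strict=False:

import torch

# Sketch: keep only the checkpoint tensors whose shapes still match the new
# model, so the resized classification layers keep their fresh initialization.
checkpoint = torch.load(PATH, map_location="cpu")
model_state = model.state_dict()

filtered = {k: v for k, v in checkpoint.items()
            if k in model_state and v.shape == model_state[k].shape}

skipped = sorted(set(model_state) - set(filtered))
print("Re-initialized (not loaded from checkpoint):", skipped)

# strict=False tolerates the missing classifier.3.weight / classifier.3.bias.
model.load_state_dict(filtered, strict=False)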

Motivation, pitch

I'm working on object detection using RetinaNet and I am facing the above-mentioned issue.

Alternatives

No response

Additional context

No response

cc @datumbox

@datumbox
Contributor

@Monk5088 there is a similar discussion for FCOS at #5932. You might want to adapt the solution from this comment to your case: #5932 (comment)
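
For reference, a rough sketch of that kind of head swap, assuming torchvision's RetinaNet (the linked comment's exact fix may differ, and a custom implementation will use different attribute names):

import torch
from torchvision.models.detection import retinanet_resnet50_fpn
from torchvision.models.detection.retinanet import RetinaNetClassificationHead

# Build the model, then replace its classification head so it predicts a
# different number of classes (torchvision >= 0.13 API).
model = retinanet_resnet50_fpn(weights=None)
num_classes = 2  # new dataset

in_channels = model.backbone.out_channels
num_anchors = model.anchor_generator.num_anchors_per_location()[0]
model.head.classification_head = RetinaNetClassificationHead(
    in_channels, num_anchors, num_classes
)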


Monk5088 commented Sep 21, 2022

Yeah, that solution in #5932 looks like it would work, but in my case it's a classic RetinaNet architecture with a ResNet-34 backbone, so I need to remove the normal and constant init lines from the code. Is there any other modification that I need to make, @datumbox?
Here is the link to my implementation of RetinaNet:
https://github.com/ChristianMarzahl/ObjectDetection/blob/master/object_detection_fastai/models/RetinaNet.py
It would be really helpful if you could mention the changes I need to make for my custom RetinaNet model.
Thanks and regards,
Harshit
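
As a hedged sketch of what those "normal" and "constant" lines typically do: they apply RetinaNet's standard focal-loss initialization to the classifier's output convolution. The classifier[-1] access below is an assumption based on the classifier.3 keys in the traceback and may not match the linked implementation exactly.

import math
import torch.nn as nn

# Sketch: re-initialize the (resized) classification output conv the way the
# RetinaNet paper does, so training starts with ~1% foreground probability.
final_conv = model.classifier[-1]  # assumed name; "classifier.3" in the traceback above

prior = 0.01
nn.init.normal_(final_conv.weight, std=0.01)
nn.init.constant_(final_conv.bias, -math.log((1 - prior) / prior))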
