[RFC] Registration mechanism for models #6330
Thanks @datumbox, just one question from me for now.
Some ways we wanted to use the registration mechanism were:
- define which weights correspond to `pretrained=True` so we can document this automatically
- get the available weights from a model name / builder
Do we intend to plan for this now, or later?
@NicolasHug Thanks for the feedback.
I don't think this is something that the registration mechanism should handle because the
Yes we can do that. We already have a private method that can achieve exactly this called |
I don't think this is something that the registration mechanism should handle because the pretrained=True will go away in a few versions
In my mind, it doesn't really matter when `pretrained=True` will be fully removed. This kind of info will be relevant and useful even after it's completely removed. For example, people looking at code that still uses this syntax (e.g. implementations from past papers) would need to know which weights it corresponds to.
We can do it in another decorator / registration mechanism, but it looks like it might have some overlap with this one.
About the "quantized_xyz" name: it's a bit of a shame that the name doesn't match the model builder's name, but as discussed earlier I don't think there's a much better way than what you did here.
Thanks everyone for their feedback and input. I will close this PR and send a new one with all the proposed changes. This way we will keep this live RFC doc easy to read.
Fixes #1143
Note: This PR is meant to document the RFC with code examples and it's not meant to be merged. Some tests are failing because they are not compatible with the mechanism. This will be resolved in the follow-up PRs that will actually introduce the mechanism to TorchVision.
Objective
Create a model registration mechanism responsible for:
Background
As discussed at #5088, a model registration mechanism is needed to be able to quickly list the available modules.
Currently we use hacky workarounds in order to list our models:
vision/test/test_models.py
Lines 24 to 30 in 6ca9c76
The same applies for initializing models from their names:
vision/references/classification/train.py
Lines 223 to 225 in e288f6c
vision/torchvision/models/detection/backbone_utils.py
Line 112 in 6ca9c76
Overview
Offer a mechanism that stores all models in a single private global dictionary, along with public getter/listing methods located under `torchvision.models` that allow users to interact with the mechanism. This proposal aligns closely with the registration mechanism proposed for the Datasets V2 API, offering a similar user experience.
List models
Listing all available models:
Listing available models from a specific module:
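A minimal, self-contained sketch of how both listing calls could behave, assuming a private dictionary keyed by model name. The names `_REGISTRY`, `list_models`, `_placeholder_builder` and the `zoo.*` modules are hypothetical stand-ins chosen for illustration, not the TorchVision implementation.

```python
from types import ModuleType
from typing import Callable, Dict, List, Optional

# Hypothetical private registry: public model name -> builder callable.
_REGISTRY: Dict[str, Callable[..., str]] = {}


def _placeholder_builder(name: str, module_name: str) -> Callable[..., str]:
    """Create a dummy builder that pretends to be defined in `module_name`."""

    def builder(**kwargs) -> str:
        return f"<{name} instance>"

    builder.__name__ = name
    builder.__module__ = module_name
    return builder


# Stand-in modules (assumption: real builders would live in real submodules).
classification = ModuleType("zoo.classification")
detection = ModuleType("zoo.detection")

for _name, _module in [
    ("resnet50", classification),
    ("alexnet", classification),
    ("fasterrcnn", detection),
]:
    _REGISTRY[_name] = _placeholder_builder(_name, _module.__name__)


def list_models(module: Optional[ModuleType] = None) -> List[str]:
    """Return registered names, optionally restricted to builders from `module`."""
    return sorted(
        name
        for name, builder in _REGISTRY.items()
        if module is None or builder.__module__ == module.__name__
    )


all_models = list_models()
detection_models = list_models(module=detection)
print(all_models)
print(detection_models)
```

Filtering on the builder's `__module__` keeps the registry a flat name-to-callable mapping, so no per-submodule bookkeeping is needed.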
Get models
Getting an initialized model with pre-trained weights from its name:
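As a rough sketch of the lookup described above, a getter can resolve the name in the registry and forward any keyword arguments to the builder. Everything here is an assumption made for illustration: `get_model`, `_REGISTRY` and `toy_resnet` are hypothetical names, and the builder returns a plain dictionary instead of a real network.

```python
from typing import Any, Callable, Dict, Optional

# Hypothetical private registry: public model name -> builder callable.
_REGISTRY: Dict[str, Callable[..., Any]] = {}


def toy_resnet(weights: Optional[str] = None) -> Dict[str, Any]:
    """Placeholder builder; a real one would construct a network and load weights."""
    return {"arch": "toy_resnet", "weights": weights}


_REGISTRY["toy_resnet"] = toy_resnet


def get_model(name: str, **config: Any) -> Any:
    """Look up the builder registered under `name` and call it with `config`."""
    try:
        builder = _REGISTRY[name]
    except KeyError:
        raise ValueError(f"Unknown model {name!r}") from None
    return builder(**config)


model = get_model("toy_resnet", weights="DEFAULT")
print(model)
```

Forwarding `**config` unchanged means the getter does not need to know which arguments each builder accepts.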
Get weights
Getting the Weight Enum class of a model from its name:
Getting the Weight Enum class of a model from its callable (model builder):
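One possible way to support both lookups is to read the annotation of the builder's `weights` parameter. The sketch below assumes that convention; `get_model_weights`, `ToyResNet_Weights`, `toy_resnet` and the URL are hypothetical and only illustrate the idea.

```python
from enum import Enum
from typing import Callable, Dict, Optional, Type, Union, get_args, get_type_hints


class ToyResNet_Weights(Enum):
    # Hypothetical weight entries; the URL is a placeholder.
    V1 = "https://example.invalid/toy_resnet_v1.pth"
    DEFAULT = V1  # same value, so this becomes an alias of V1


def toy_resnet(weights: Optional[ToyResNet_Weights] = None) -> dict:
    """Placeholder builder whose `weights` annotation names its Enum class."""
    return {"arch": "toy_resnet", "weights": weights}


# Hypothetical private registry: public model name -> builder callable.
_REGISTRY: Dict[str, Callable] = {"toy_resnet": toy_resnet}


def get_model_weights(name_or_builder: Union[str, Callable]) -> Type[Enum]:
    """Return the weights Enum class for a registered name or a builder callable."""
    if isinstance(name_or_builder, str):
        builder = _REGISTRY[name_or_builder]
    else:
        builder = name_or_builder
    annotation = get_type_hints(builder).get("weights")
    if annotation is None:
        raise ValueError("The builder has no annotated `weights` parameter.")
    # Optional[X] is Union[X, None]; scan its members for an Enum subclass.
    for candidate in get_args(annotation) or (annotation,):
        if isinstance(candidate, type) and issubclass(candidate, Enum):
            return candidate
    raise ValueError("The `weights` annotation does not reference an Enum class.")


weights_from_name = get_model_weights("toy_resnet")
weights_from_builder = get_model_weights(toy_resnet)
print(weights_from_name, weights_from_builder)
```

Deriving the Enum from the signature avoids storing a second mapping next to the registry, at the cost of requiring every builder to annotate `weights`.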
Register models
To register a model we use a special decorator and provide an arbitrary name:
If you are registering a method under its method name, you can omit passing a string. The following call will register the method under the name `"some_model"`:
Registering a model under an existing name throws an error.
Other details
The model builder names under `torchvision.models.quantization` conflict with those under `torchvision.models`. As a result, we need to offer the ability to register them under arbitrary names that resolve the conflicts.
Alternatives Considered
We considered offering separate registration APIs per submodule (i.e. one for `torchvision.models`, one for `torchvision.models.detection`, etc.). This would resolve the issue of name conflicts across submodules but would lead to having separate public methods per package. Moreover, this would lead to an awkward design, since methods like `torchvision.models.get_weight()` already support all model subpackages.