Skip to content

[RFC] Registration mechanism for models #6330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 14 commits into from

Conversation

datumbox
Copy link
Contributor

@datumbox datumbox commented Jul 28, 2022

Fixes #1143

Note: This PR is meant to document the RFC with code examples and it's not mean to be merged. Some tests are failing because they are not compatible with the mechanism. This will be resolved on the follow up PRs that will actually introduce the mechanism to TorchVision.

Objective

Create a model registration mechanism responsible for:

  • Listing the names of all available models under a specific module
  • Getting a model (initializing) from its name
  • Getting the weight enum of the model from its name or model builder
  • Registering models under a specific name

Background

As discussed at #5088, a model registration mechanism is needed to be able to quickly list the available modules.

Currently we use hacky workarounds in order to list our models:

def get_models_from_module(module):
# TODO add a registration mechanism to torchvision.models
return [
v
for k, v in module.__dict__.items()
if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight"
]

The same applies for initializing models from their names:

print("Creating model")
model = torchvision.models.__dict__[args.model](weights=args.weights, num_classes=num_classes)
model.to(device)

backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)

Overview

Offer a mechanism that stores all models on a single private global dictionary along with public getter/listing methods located under torchvision.models that will allow users to interact with the mechanism. This proposal aligns closely with the registration mechanism proposed at the Datasets V2 API, offering a similar user experience.

List models

Listing all available models:

>>> torchvision.models.list_models()
['alexnet', 'mobilenet_v3_large', 'mobilenet_v3_small', 'quantized_mobilenet_v3_large', ...]

Listing available models from a specific module:

>>> torchvision.models.list_models(module=torchvision.models)
['alexnet', 'mobilenet_v3_large', 'mobilenet_v3_small', ...]
>>> torchvision.models.list_models(module=torchvision.models.quantization)
['quantized_mobilenet_v3_large', ...]

Get models

Getting an initialized model with pre-trained weights from its name:

>>> torchvision.models.get_model("quantized_mobilenet_v3_large", weights="DEFAULT")
QuantizableMobileNetV3(
  (features): Sequential(
   ....
   )
)

Get weights

Getting the Weight Enum class of a model from its name:

>>> torchvision.models.get_model_weights("quantized_mobilenet_v3_large")
<enum 'MobileNet_V3_Large_QuantizedWeights'>

Getting the Weight Enum class of a model from its callable (model builder):

>>> torchvision.models.get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
<enum 'MobileNet_V3_Large_QuantizedWeights'>

Register models

To register a model we use a special decorator and provide an arbitrary name:

@register_model(name="mobilenet_v3_large")
def mobilenet_v3_large(
    *, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
    pass

@register_model(name="quantized_mobilenet_v3_large")
def mobilenet_v3_large(
    *,
    weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    pass

If you are registering a method under its method name, you can omit passing a string. The following call will register the method under the name "some_model":

@register_model()
def some_model():
    pass

Registering a model under an existing name throws an error.

Other details

  • We will keep the registration mechanism private. Perhaps on the future we can consider allowing users to use it to register custom models but that's not part of this RFC.
  • Unfortunately the names of the methods under torchvision.models.quantization conflict with those under torchvision.models. As a result we need to offer the ability to register them under arbitrary names that resolve the conflicts.
  • The current registration mechanism will throw an error if one tries to register a model under an existing name. It's possible to extend the mechanism to allow overwrites (similar to ClassyVision) but given that the mechanism is private we don't have an immediate need of this feature.

Alternatives Considered

We considered offering separate registration APIs per submodule (aka one for torchvision.models, one for torchvision.models.detection etc). This would resolve the issue of name conflicts across submodules but would lead to having separate public methods per package. Moreover this would lead to an awkward design since methods like torchvision.models.get_weight() already support all model subpackages.

@datumbox datumbox requested a review from NicolasHug July 28, 2022 18:09
@datumbox datumbox changed the title [RFC] Registration mechanism for models [RFC] [NOMERGE] Registration mechanism for models Jul 28, 2022
@datumbox datumbox marked this pull request as draft July 28, 2022 18:10
Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @datumbox , just question from me for now

Some wasys we wanted to use the registration mechanism were:

  • define which weight correspond to pretrained=True so we can document this automatically
  • get the available weights from a model name / builder

Do we intend to plan for this now, or later?

@datumbox
Copy link
Contributor Author

@NicolasHug Thanks for the feedback.

define which weight correspond to pretrained=True so we can document this automatically

I don't think this is something that the registration mechanism should handle because the pretrained=True will go away in a few versions. I think this is still useful but probably needs to be implemented on a different annotation than register_model.

get the available weights from a model name / builder

Yes we can do that. We already have a private method that can achieve exactly this called _get_enum_from_fn() but we can make one public using the string names.

Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is something that the registration mechanism should handle because the pretrained=True will go away in a few versions

In my mind, it doesn't really matter when pretrained=True will be fully removed. This kind of info will be relevant and useful even when it's completely removed. For example for people looking at code (e.g. from past papers implementations) that still use this syntax, they would need to know which weight it corresponds to.

We can do it in another decorator / registration mechanism, but it looks like it might have some overlap with this one.

about the "quantized_xyz" name: it's a bit of a shame that the name doesn't match the model builder's name, but as discussed earlier I don't think there's a much better way that what you did here.

@datumbox datumbox changed the title [RFC] [NOMERGE] Registration mechanism for models [RFC] Registration mechanism for models Jul 29, 2022
@datumbox
Copy link
Contributor Author

Thanks everyone for their feedback and input. I will close this PR and send a new one with all the proposed changes. This way we will keep this live RFC doc easy to read.

@datumbox datumbox closed this Jul 29, 2022
@datumbox datumbox deleted the models/registration branch July 29, 2022 14:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Is there a plan to add a registration mechanism to torchvision.models?
5 participants