NOMRG Comments on prototype weights and model builders #4937

18 changes: 18 additions & 0 deletions torchvision/prototype/models/_api.py
@@ -51,6 +51,19 @@ def verify(cls, obj: Any) -> Any:
if type(obj) is str:
obj = cls.from_str(obj)
elif not isinstance(obj, cls) and not isinstance(obj, WeightEntry):
# Allowing WeightEntry to pass through is unexpected IMHO.
# Since verify() is a factory method, we would expect it to
# return only Weights instances, not WeightEntry instances.
# This could lead to bugs if the caller is unaware of that
# subtly hidden detail.
# From what I understand, the motivation for allowing WeightEntry
# is user convenience when calling model builder functions
# like resnet18(). IMO for a first version, it's perfectly
# reasonable to expect users to write a small wrapper class that
# inherits from Weights instead. Maybe this could be re-visited
# in the future, but for a first version we should try
# to be conservative: if there's a decent workaround for a
# somewhat rare use-case, it's not worth special-casing it now.
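# For illustration, the wrapper could be as small as (hypothetical sketch,
# field values elided):
#
#     class MyResNet18Weights(Weights):
#         MyCheckpoint = WeightEntry(url=..., transforms=..., meta=...)
#
# and users would call resnet18(weights=MyResNet18Weights.MyCheckpoint),
# so verify() only ever returns Weights instances.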
raise TypeError(
f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
)
@@ -63,6 +76,11 @@ def from_str(cls, value: str) -> "Weights":
return v
raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")

# Nit: I would expect 'state_dict' to be an attribute, not a method.
# I understand we want to avoid
# model.load_state_dict(weights.load_state_dict(progress=progress))
# which reads strangely.
# Maybe `get_state_dict()` would better convey the idea that this is a procedure?
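# With the rename, the call site would read e.g.
#     model.load_state_dict(weights.get_state_dict(progress=progress))
# which makes it clearer that a download may happen.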
def state_dict(self, progress: bool) -> OrderedDict:
return load_state_dict_from_url(self.url, progress=progress)

5 changes: 5 additions & 0 deletions torchvision/prototype/models/densenet.py
@@ -26,6 +26,8 @@


def _load_state_dict(model: nn.Module, weights: Weights, progress: bool) -> None:
# Could you explain the comments below a bit more? Why are dots no longer allowed?

# '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
Expand Down Expand Up @@ -67,6 +69,9 @@ def _densenet(
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
# Is it worth making this trailing comment the actual value? e.g.
# "recipe": "weights ported from LuaTorch",
# ?
"recipe": None, # weights ported from LuaTorch
}

21 changes: 20 additions & 1 deletion torchvision/prototype/models/detection/faster_rcnn.py
@@ -147,11 +147,25 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
return model


# Unfortunately, the addition and removal of parameters here is BC-breaking for
# users that are using positional arguments.
#
# It might be possible to preserve BC by allowing:
# - weights to be a bool: it maps to pretrained
# - weights_backbone to be a bool: it maps to progress
# - progress to be an int: it maps to num_classes
# - num_classes to be a bool: it maps to pretrained_backbone
# and raising a warning in all these cases; see the sketch below.
#
# That would complicate the code though, so we should only do this if we are
# 100% sure to remove these deprecated behaviours within 1 or 2 versions.
# Also, this only works because the types of the new and old parameters
# don't overlap.
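# A rough sketch of such a shim (purely hypothetical: the helper name and the
# enum member names are assumptions, and edge cases are glossed over):
def _remap_legacy_positional_args(weights, weights_backbone, progress, num_classes):
    # Old positional order was (pretrained, progress, num_classes, pretrained_backbone).
    # Each branch mirrors one bullet from the comment above; note that bool is a
    # subclass of int, hence the explicit not-a-bool check for progress.
    remapped = [weights, weights_backbone, progress, num_classes]
    if isinstance(weights, bool):  # this positional was 'pretrained'
        warnings.warn("The old positional signature is deprecated, please use keyword arguments.")
        remapped[0] = FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1 if weights else None
    if isinstance(weights_backbone, bool):  # this positional was 'progress'
        remapped[1], remapped[2] = None, weights_backbone
    if isinstance(progress, int) and not isinstance(progress, bool):  # old 'num_classes'
        remapped[3] = progress
    if isinstance(num_classes, bool):  # this positional was 'pretrained_backbone'
        remapped[1] = MobileNetV3LargeWeights.ImageNet1K_RefV1 if num_classes else None
    return tuple(remapped)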
def fasterrcnn_mobilenet_v3_large_fpn(
weights: Optional[FasterRCNNMobileNetV3LargeFPNWeights] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
progress: bool = True,
num_classes: int = 91, # just wondering why num_classes is a parameter here instead of being part of kwargs like for other models?
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
@@ -189,6 +203,11 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
) -> FasterRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
# These long weight names are a bit of a mouthful lol.
# I don't have a better proposal TBH.
# I think the complexity of those weight names can be greatly reduced by:
# - good docs (which I know are part of the future work)
# - good "preset values" like "pretrained" or "latest", or whatever is easy for
#   the user to remember, as suggested in another comment
weights = FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
8 changes: 8 additions & 0 deletions torchvision/prototype/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
@@ -48,6 +48,14 @@ class KeypointRCNNResNet50FPNWeights(Weights):
)


# This is BC-breaking too, and unfortunately I'm not sure we can raise warnings
# like in fasterrcnn_mobilenet_v3_large_fpn, because we wouldn't be able
# to differentiate num_keypoints from num_classes.
#
# I'm not sure what the best course of action would be here.
#
# Not a solution, but related: these issues are a strong argument in favour of
# forcing keyword-only parameters in new APIs.
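# For reference, a keyword-only signature would look like (sketch, other
# parameters elided):
#     def keypointrcnn_resnet50_fpn(*, weights=None, weights_backbone=None, progress=True, ...):
# so any positional call fails loudly with a TypeError instead of silently
# reinterpreting arguments.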
def keypointrcnn_resnet50_fpn(
weights: Optional[KeypointRCNNResNet50FPNWeights] = None,
weights_backbone: Optional[ResNet50Weights] = None,
3 changes: 3 additions & 0 deletions torchvision/prototype/models/googlenet.py
@@ -37,8 +37,10 @@ def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True,
original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None:
if "transform_input" not in kwargs:
# Similar to the comment about num_classes: I think we should raise an error if the user manually specified transform_input=False.
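# e.g. (sketch, to be placed before this 'if'):
#     if kwargs.get("transform_input") is False:
#         raise ValueError("transform_input=False is incompatible with pre-trained weights.")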
kwargs["transform_input"] = True
if original_aux_logits:
# Nit: maybe move this warning into an `else` clause below
warnings.warn(
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
)
@@ -50,6 +52,7 @@

if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
# I understand this is present in the current code, just curious why this is needed?
if not original_aux_logits:
model.aux_logits = False
model.aux1 = None # type: ignore[assignment]
30 changes: 29 additions & 1 deletion torchvision/prototype/models/resnet.py
@@ -41,6 +41,16 @@ def _resnet(
**kwargs: Any,
) -> ResNet:
if weights is not None:
# I remember Francisco raising a concern about overriding num_classes.
# IMHO this is fine, I would consider this to be a case of
# "we have param A and param B, and param B's default depends on the value of A",
# which is a very common pattern.
# However, to prevent users from doing the wrong thing, I think we could raise
# an error if the user specified an incorrect 'num_classes':
num_categories = len(weights.meta["categories"])
if "num_classes" in kwargs and kwargs["num_classes"] != num_categories:
raise ValueError(
f"Oops, you specified num_classes={kwargs['num_classes']} but this is incompatible "
f"with the pre-trained weights, which support {num_categories} classes."
)
kwargs["num_classes"] = num_categories

model = ResNet(block, layers, **kwargs)
@@ -51,6 +61,7 @@
return model


# Nit: upper-case for global vars? e.g. _COMMON_META
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}


@@ -231,11 +242,28 @@ class WideResNet101_2Weights(Weights):
)


# Maybe we should rename kwargs to model_kwargs to be more explicit? I believe this is not a BC-breaking change.
def resnet18(weights: Optional[ResNet18Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
# Considering we're showing this warning literally everywhere, it might be worth writing a helper for it
# This would make sure the warning message is consistent, avoid code duplication, etc.
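# Something like (hypothetical sketch):
#     def _warn_deprecated_param(old: str, new: str) -> None:
#         warnings.warn(f"The '{old}' parameter is deprecated, please use the '{new}' parameter instead.")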
if "pretrained" in kwargs:
# Nit: "argument" should probably be "parameter", since the term
# "argument" refers to the *value* of the parameter at call time.
warnings.warn("The 'pretrained' parameter is deprecated, please use the 'weights' parameter instead.")
weights = ResNet18Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None

# To preserve full BC with the current resnet18, weights could also accept
# a boolean and issue a deprecation warning as well. This covers
# cases where resnet18 is called as resnet18(True) or resnet18(False),
# something like this:
if isinstance(weights, bool):
warnings.warn("The 'pretrained' parameter is deprecated, please use the 'weights' parameter instead.")
weights = ResNet18Weights.ImageNet1K_RefV1 if weights else None

# I remember we discussed this before: before releasing this new API, we
# might want to allow users to pass simple strings, e.g.
# resnet18(weights='pretrained') or resnet18(weights='latest').
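# For instance (purely a sketch, assuming each Weights enum nominates a
# recommended entry), verify() could translate a couple of well-known aliases:
#     if obj in ("pretrained", "latest"):
#         obj = cls.default()  # hypothetical classmethod returning the recommended entry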

weights = ResNet18Weights.verify(weights)

return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)