diff --git a/torchvision/prototype/models/_api.py b/torchvision/prototype/models/_api.py
index 4961d7def50..595fdf88b66 100644
--- a/torchvision/prototype/models/_api.py
+++ b/torchvision/prototype/models/_api.py
@@ -51,6 +51,19 @@ def verify(cls, obj: Any) -> Any:
         if type(obj) is str:
             obj = cls.from_str(obj)
         elif not isinstance(obj, cls) and not isinstance(obj, WeightEntry):
+            # Allowing WeightEntry to pass through is unexpected IMHO.
+            # Since verify() is a factory method, we would expect it to
+            # return only Weights instances, not WeightEntry instances.
+            # This could lead to bugs if the caller is unaware of that
+            # subtly hidden detail.
+            # From what I understand, the motivation for allowing WeightEntry
+            # is user convenience when calling model builder functions
+            # like resnet18(). IMO for a first version, it's perfectly
+            # reasonable to expect users to write a small wrapper class that
+            # inherits from Weights instead. Maybe this could be revisited
+            # in the future, but for a first version, I think we should try
+            # to be conservative: if there's a decent way of achieving a
+            # somewhat rare use-case, it's not worth worrying about for now.
             raise TypeError(
                 f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
             )
@@ -63,6 +76,11 @@ def from_str(cls, value: str) -> "Weights":
                 return v
         raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")
 
+    # Nit: I would expect 'state_dict' to be an attribute, not a method.
+    # I understand we want to avoid something like
+    #     model.load_state_dict(weights.load_state_dict(progress=progress))
+    # which reads strangely.
+    # Maybe `get_state_dict()` would better convey the idea that this is a procedure?
     def state_dict(self, progress: bool) -> OrderedDict:
         return load_state_dict_from_url(self.url, progress=progress)
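To make the `get_state_dict()` suggestion above concrete, here is a minimal sketch of how the method and its call site could read. This is illustrative only: the single-member enum and its `url` property are simplified stand-ins for the real Weights/WeightEntry structure, and the member name and URL are made up for the example.

    from collections import OrderedDict
    from enum import Enum

    from torch.hub import load_state_dict_from_url


    class ExampleWeights(Enum):
        # Hypothetical member; the real enum members wrap a WeightEntry.
        ImageNet1K_RefV1 = "https://download.pytorch.org/models/resnet18-f37072fd.pth"

        @property
        def url(self) -> str:
            return self.value

        def get_state_dict(self, progress: bool) -> OrderedDict:
            # The verb makes it clear this is a procedure (it downloads and
            # caches), and the call site reads naturally:
            #     model.load_state_dict(weights.get_state_dict(progress=True))
            return load_state_dict_from_url(self.url, progress=progress)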
diff --git a/torchvision/prototype/models/densenet.py b/torchvision/prototype/models/densenet.py
index db0c742e48d..76a7e3c716e 100644
--- a/torchvision/prototype/models/densenet.py
+++ b/torchvision/prototype/models/densenet.py
@@ -26,6 +26,8 @@
 
 def _load_state_dict(model: nn.Module, weights: Weights, progress: bool) -> None:
+    # Could you explain the comments below a bit more? Why are dots no longer allowed?
+
     # '.'s are no longer allowed in module names, but previous _DenseLayer
     # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
     # They are also in the checkpoints in model_urls. This pattern is used
@@ -67,6 +69,9 @@ def _densenet(
         "size": (224, 224),
         "categories": _IMAGENET_CATEGORIES,
         "interpolation": InterpolationMode.BILINEAR,
+        # Is it worth adding this comment literally as the value? e.g.
+        #     "recipe": "weights ported from LuaTorch",
+        # ?
         "recipe": None,  # weights ported from LuaTorch
     }
diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py
index 66e584eec8f..cae7c0d2aac 100644
--- a/torchvision/prototype/models/detection/faster_rcnn.py
+++ b/torchvision/prototype/models/detection/faster_rcnn.py
@@ -147,11 +147,25 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
     return model
 
 
+# Unfortunately, the addition and removal of parameters here is BC-breaking for
+# users that are using positional arguments.
+#
+# It might be possible to preserve BC by allowing:
+#   - weights to be a bool: it maps to pretrained
+#   - weights_backbone to be a bool: it maps to progress
+#   - progress to be an int: it maps to num_classes
+#   - num_classes to be a bool: it maps to pretrained_backbone
+# and raising a warning in all these cases.
+#
+# That will complicate the code, though; we should only do this if we are 100%
+# sure to remove these deprecated behaviours in 1 or 2 versions.
+# Also, this only works because the types at each position of the new vs old
+# parameter order don't overlap.
 def fasterrcnn_mobilenet_v3_large_fpn(
     weights: Optional[FasterRCNNMobileNetV3LargeFPNWeights] = None,
     weights_backbone: Optional[MobileNetV3LargeWeights] = None,
     progress: bool = True,
-    num_classes: int = 91,
+    num_classes: int = 91,  # just wondering why num_classes is a parameter here instead of being part of kwargs like for other models?
     trainable_backbone_layers: Optional[int] = None,
     **kwargs: Any,
 ) -> FasterRCNN:
@@ -189,6 +203,11 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
 ) -> FasterRCNN:
     if "pretrained" in kwargs:
         warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        # These long weight names are a bit of a mouthful lol
+        # I don't have a better proposal TBH.
+        # I think the complexity of those weight names can be greatly reduced by:
+        #   - good docs (which I know is part of the future work)
+        #   - good "preset values" like "pretrained" or "latest" or whatever is easy for the user to remember, as suggested in another comment
         weights = FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
     weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights)
     if "pretrained_backbone" in kwargs:
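A rough sketch of the positional-BC shim described in the comment above, under the assumption stated there (the old and new types at each position don't overlap). The helper name, warning text and fallback defaults are all made up for illustration; this is not part of the PR.

    import warnings
    from typing import Any, Tuple


    def _remap_legacy_positional_args(
        weights: Any, weights_backbone: Any, progress: Any, num_classes: Any
    ) -> Tuple[bool, bool, int, bool]:
        # Hypothetical helper. A legacy positional call such as
        #     fasterrcnn_mobilenet_v3_large_fpn(True, True, 21, False)
        # arrives as weights=pretrained, weights_backbone=progress,
        # progress=num_classes, num_classes=pretrained_backbone; a bool in the
        # 'weights' slot is what identifies the old calling convention.
        warnings.warn(
            "Positional use of the old (pretrained, progress, num_classes, "
            "pretrained_backbone) order is deprecated; please use the new "
            "'weights'/'weights_backbone' parameters."
        )
        pretrained = weights
        legacy_progress = weights_backbone if isinstance(weights_backbone, bool) else True
        # bool is a subclass of int, so check bool explicitly: a bool here is
        # just the new 'progress' default, not a legacy num_classes.
        legacy_num_classes = progress if not isinstance(progress, bool) else 91
        pretrained_backbone = num_classes if isinstance(num_classes, bool) else True
        return pretrained, legacy_progress, legacy_num_classes, pretrained_backbone

The builder would call this behind an `if isinstance(weights, bool):` guard and then translate pretrained/pretrained_backbone into the corresponding weight enum values, exactly like the deprecated keyword path above already does.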
diff --git a/torchvision/prototype/models/detection/keypoint_rcnn.py b/torchvision/prototype/models/detection/keypoint_rcnn.py
index 48b6640e044..347485c04cd 100644
--- a/torchvision/prototype/models/detection/keypoint_rcnn.py
+++ b/torchvision/prototype/models/detection/keypoint_rcnn.py
@@ -48,6 +48,14 @@ class KeypointRCNNResNet50FPNWeights(Weights):
     )
 
 
+# This is BC-breaking too, and unfortunately I'm not sure we can raise warnings
+# like in fasterrcnn_mobilenet_v3_large_fpn, because we wouldn't be able
+# to differentiate num_keypoints from num_classes.
+#
+# I'm not sure what the best course of action could be here.
+#
+# Not a solution, but related: these issues are a strong argument in favour of
+# forcing keyword-only parameters in new APIs.
 def keypointrcnn_resnet50_fpn(
     weights: Optional[KeypointRCNNResNet50FPNWeights] = None,
     weights_backbone: Optional[ResNet50Weights] = None,
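The keyword-only enforcement mentioned above is a single bare `*` in the signature. A minimal sketch with placeholder annotations (the real builder returns a KeypointRCNN; the defaults below mirror the existing builder):

    from typing import Any, Optional


    def keypointrcnn_resnet50_fpn(
        *,  # bare star: everything after it must be passed by keyword
        weights: Optional[Any] = None,
        weights_backbone: Optional[Any] = None,
        progress: bool = True,
        num_classes: int = 2,
        num_keypoints: int = 17,
        **kwargs: Any,
    ) -> None:
        ...


    # keypointrcnn_resnet50_fpn(True, 21) would now raise a TypeError
    # immediately, instead of silently binding True to 'weights' and
    # 21 to 'weights_backbone'.
    keypointrcnn_resnet50_fpn(num_classes=2, num_keypoints=17)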
+ warnings.warn("The 'pretrained' parameter is deprecated, please use the 'weights' parameter instead.") weights = ResNet18Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + # To preserve full BC with the current resnet18, weights coudl also accept + # a boolean and this would issue deprecation warning as well. This is for + # cases where resnet18 is called with resnet18(True) or resnet18(False) + # something like this: + if isinstance(weights, bool): + warnings.warn("The 'pretrained' parameter is deprecated, please use the 'weights' parameter instead.") + weights = ResNet18Weights.ImageNet1K_RefV1 if weights else None + + # I remember we discussed this before: before releasing this new API, we + # mght want to try to allow users to pass simple stuff like e.g. + # resnet18(weights='pretrained') or resnet18(weights='latest') + weights = ResNet18Weights.verify(weights) return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)