diff --git a/hubconf.py b/hubconf.py index 28d2c5a5d01..a229ab07667 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,6 +1,7 @@ # Optional list of dependencies required by the package dependencies = ["torch"] +from torchvision.models import get_weight from torchvision.models.alexnet import alexnet from torchvision.models.convnext import convnext_tiny, convnext_small, convnext_base, convnext_large from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161 diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 05b35fe87f0..a202ed625b5 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -107,7 +107,7 @@ def get_weight(name: str) -> WeightsEnum: return weights_enum.from_str(value_name) -def get_enum_from_fn(fn: Callable) -> WeightsEnum: +def _get_enum_from_fn(fn: Callable) -> WeightsEnum: """ Internal method that gets the weight enum of a specific model builder method. Might be removed after the handle_legacy_interface is removed. diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index 65fe45c4cbd..fbef524b99c 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -6,7 +6,7 @@ from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool from .. import mobilenet, resnet -from .._api import WeightsEnum, get_enum_from_fn +from .._api import WeightsEnum, _get_enum_from_fn from .._utils import IntermediateLayerGetter, handle_legacy_interface @@ -62,7 +62,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: @handle_legacy_interface( weights=( "pretrained", - lambda kwargs: get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"), + lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"), ), ) def resnet_fpn_backbone( @@ -177,7 +177,7 @@ def _validate_trainable_layers( @handle_legacy_interface( weights=( "pretrained", - lambda kwargs: get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"), + lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"), ), ) def mobilenet_backbone(