Skip to content

Expose get_weight to Torch Hub #6026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Optional list of dependencies required by the package
dependencies = ["torch"]

from torchvision.models import get_weight
from torchvision.models.alexnet import alexnet
from torchvision.models.convnext import convnext_tiny, convnext_small, convnext_base, convnext_large
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def get_weight(name: str) -> WeightsEnum:
return weights_enum.from_str(value_name)


def get_enum_from_fn(fn: Callable) -> WeightsEnum:
def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
"""
Internal method that gets the weight enum of a specific model builder method.
Might be removed after the handle_legacy_interface is removed.
Expand Down
6 changes: 3 additions & 3 deletions torchvision/models/detection/backbone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool

from .. import mobilenet, resnet
from .._api import WeightsEnum, get_enum_from_fn
from .._api import WeightsEnum, _get_enum_from_fn
from .._utils import IntermediateLayerGetter, handle_legacy_interface


Expand Down Expand Up @@ -62,7 +62,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
),
)
def resnet_fpn_backbone(
Expand Down Expand Up @@ -177,7 +177,7 @@ def _validate_trainable_layers(
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
),
)
def mobilenet_backbone(
Expand Down