Document all pre-trained Classification weights #6036

Merged · 7 commits · May 17, 2022
13 changes: 8 additions & 5 deletions docs/source/conf.py
@@ -335,8 +335,6 @@ def inject_weight_metadata(app, what, name, obj, options, lines):

     for field in obj:
         lines += [f"**{str(field)}**:", ""]
-        if field == obj.DEFAULT:
-            lines += [f"This weight is also available as ``{obj.__name__}.DEFAULT``.", ""]

         table = []

@@ -349,7 +347,12 @@ def inject_weight_metadata(app, what, name, obj, options, lines):

         custom_docs = meta_with_metrics.pop("_docs", None)  # Custom per-Weights docs
         if custom_docs is not None:
-            lines += [custom_docs, ""]
+            lines += [custom_docs]
+
+        if field == obj.DEFAULT:
+            lines += [f"Also available as ``{obj.__name__}.DEFAULT``."]
+
+        lines += [""]

         for k, v in meta_with_metrics.items():
             if k in {"recipe", "license"}:
@@ -367,8 +370,8 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
         lines += textwrap.indent(table, " " * 4).split("\n")
         lines.append("")
         lines.append(
-            f"The preprocessing/inference transforms are available at ``{str(field)}.transforms`` and "
-            f"perform the following operations: {field.transforms().describe()}"
+            f"The inference transforms are available at ``{str(field)}.transforms`` and "
+            f"perform the following preprocessing operations: {field.transforms().describe()}"
         )
         lines.append("")

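What the change does: each weight entry now renders its custom ``_docs`` prose first, then the ``DEFAULT`` alias note, then a single trailing blank line, instead of printing the alias note at the top. Below is a minimal, runnable sketch of the new ordering; it uses a stub enum in place of torchvision's real WeightsEnum and Sphinx hook, so the names are illustrative.

from enum import Enum

class AlexNetWeightsStub(Enum):
    # Hypothetical stand-in for a torchvision WeightsEnum.
    IMAGENET1K_V1 = "imagenet1k_v1"
    DEFAULT = IMAGENET1K_V1  # enum alias, mirroring WeightsEnum.DEFAULT

def render(obj, meta):
    # Mirrors the new ordering in inject_weight_metadata:
    # custom docs first, then the DEFAULT note, then one blank line.
    lines = []
    for field in obj:  # Enum iteration skips aliases such as DEFAULT
        lines += [f"**{field}**:", ""]
        custom_docs = meta.pop("_docs", None)
        if custom_docs is not None:
            lines += [custom_docs]
        if field == obj.DEFAULT:
            lines += [f"Also available as ``{obj.__name__}.DEFAULT``."]
        lines += [""]
    return lines

print("\n".join(render(AlexNetWeightsStub, {"_docs": "These weights are ported from the original paper."})))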
2 changes: 1 addition & 1 deletion test/test_extended_models.py
@@ -96,7 +96,7 @@ def test_schema_meta_validation(model_fn):
     classification_fields = {"categories", ("metrics", "acc@1"), ("metrics", "acc@5")}
     defaults = {
         "all": {"metrics", "min_size", "num_params", "recipe"},
-        "models": classification_fields,
+        "models": classification_fields | {"_docs"},
         "detection": {"categories", ("metrics", "box_map")},
         "quantization": classification_fields | {"backend", "unquantized"},
         "segmentation": {"categories", ("metrics", "miou"), ("metrics", "pixel_acc")},
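The test change simply whitelists the new ``_docs`` key for classification models. A rough sketch of the kind of check this test performs (an illustration of the idea, not the test's exact internals): every key in a weight's meta, with nested metrics flattened to tuples, must belong to the permitted set.

classification_fields = {"categories", ("metrics", "acc@1"), ("metrics", "acc@5")}
permitted = {"metrics", "min_size", "num_params", "recipe"} | classification_fields | {"_docs"}

def flatten(meta):
    # Turn {"metrics": {"acc@1": ...}} into keys like ("metrics", "acc@1").
    for k, v in meta.items():
        if isinstance(v, dict):
            yield from ((k, inner) for inner in v)
        else:
            yield k

meta = {
    "num_params": 61_100_840,
    "min_size": (63, 63),
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification",
    "categories": ["tench", "goldfish"],  # truncated for the example
    "metrics": {"acc@1": 56.522, "acc@5": 79.066},
    "_docs": "These weights reproduce closely the results of the paper.",
}
assert all(k in permitted for k in flatten(meta))  # passes only because "_docs" is now permitted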
3 changes: 3 additions & 0 deletions torchvision/models/alexnet.py
@@ -65,6 +65,9 @@ class AlexNet_Weights(WeightsEnum):
                 "acc@1": 56.522,
                 "acc@5": 79.066,
             },
+            "_docs": """
+                These weights reproduce closely the results of the paper using a simplified training recipe.
+            """,
         },
     )
     DEFAULT = IMAGENET1K_V1
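Once merged, the new field is ordinary metadata and can be inspected at runtime. A small usage example, assuming a torchvision build that includes this PR (key names follow this PR's schema and may differ in later releases):

from torchvision.models import AlexNet_Weights

weights = AlexNet_Weights.DEFAULT        # alias of IMAGENET1K_V1
print(weights.meta["_docs"])             # the per-weight docstring added above
print(weights.meta["metrics"]["acc@1"])  # 56.522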
5 changes: 5 additions & 0 deletions torchvision/models/convnext.py
@@ -207,6 +207,11 @@ def _convnext(
     "min_size": (32, 32),
     "categories": _IMAGENET_CATEGORIES,
     "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
+    "_docs": """
+        These weights improve upon the results of the original paper by using a modified version of TorchVision's
+        `new training recipe
+        <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
+    """,
 }


1 change: 1 addition & 0 deletions torchvision/models/densenet.py
@@ -261,6 +261,7 @@ def _densenet(
     "min_size": (29, 29),
     "categories": _IMAGENET_CATEGORIES,
     "recipe": "https://github.com/pytorch/vision/pull/116",
+    "_docs": """These weights are ported from LuaTorch.""",
 }


38 changes: 35 additions & 3 deletions torchvision/models/efficientnet.py
@@ -431,24 +431,26 @@ def _efficientnet_conf(

 _COMMON_META: Dict[str, Any] = {
     "categories": _IMAGENET_CATEGORIES,
-    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
 }


 _COMMON_META_V1 = {
     **_COMMON_META,
     "min_size": (1, 1),
+    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v1",
 }


 _COMMON_META_V2 = {
     **_COMMON_META,
     "min_size": (33, 33),
+    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v2",
 }


 class EfficientNet_B0_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
+        # Weights ported from https://github.com/rwightman/pytorch-image-models/
         url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
         transforms=partial(
             ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC
@@ -460,13 +462,15 @@ class EfficientNet_B0_Weights(WeightsEnum):
                 "acc@1": 77.692,
                 "acc@5": 93.532,
             },
+            "_docs": """These weights are ported from the original paper.""",
         },
     )
     DEFAULT = IMAGENET1K_V1


 class EfficientNet_B1_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
+        # Weights ported from https://github.com/rwightman/pytorch-image-models/
         url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
         transforms=partial(
             ImageClassification, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC
@@ -478,6 +482,7 @@ class EfficientNet_B1_Weights(WeightsEnum):
                 "acc@1": 78.642,
                 "acc@5": 94.186,
             },
+            "_docs": """These weights are ported from the original paper.""",
         },
     )
     IMAGENET1K_V2 = Weights(
@@ -493,13 +498,19 @@ class EfficientNet_B1_Weights(WeightsEnum):
                 "acc@1": 79.838,
                 "acc@5": 94.934,
             },
+            "_docs": """
+                These weights improve upon the results of the original paper by using a modified version of TorchVision's
+                `new training recipe
+                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
+            """,
         },
     )
     DEFAULT = IMAGENET1K_V2


 class EfficientNet_B2_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
+        # Weights ported from https://github.com/rwightman/pytorch-image-models/
         url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
         transforms=partial(
             ImageClassification, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC
@@ -511,13 +522,15 @@ class EfficientNet_B2_Weights(WeightsEnum):
                 "acc@1": 80.608,
                 "acc@5": 95.310,
             },
+            "_docs": """These weights are ported from the original paper.""",
         },
     )
     DEFAULT = IMAGENET1K_V1


 class EfficientNet_B3_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
+        # Weights ported from https://github.com/rwightman/pytorch-image-models/
         url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
         transforms=partial(
             ImageClassification, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC
@@ -529,13 +542,15 @@ class EfficientNet_B3_Weights(WeightsEnum):
                 "acc@1": 82.008,
                 "acc@5": 96.054,
             },
+            "_docs": """These weights are ported from the original paper.""",
         },
     )
     DEFAULT = IMAGENET1K_V1


 class EfficientNet_B4_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
+        # Weights ported from https://github.com/rwightman/pytorch-image-models/
         url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
         transforms=partial(
             ImageClassification, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC
@@ -547,13 +562,15 @@ class EfficientNet_B4_Weights(WeightsEnum):
                 "acc@1": 83.384,
                 "acc@5": 96.594,
             },
+            "_docs": """These weights are ported from the original paper.""",
         },
     )
     DEFAULT = IMAGENET1K_V1


 class EfficientNet_B5_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
+        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
         url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
         transforms=partial(
             ImageClassification, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC
@@ -565,13 +582,15 @@ class EfficientNet_B5_Weights(WeightsEnum):
                 "acc@1": 83.444,
                 "acc@5": 96.628,
             },
+            "_docs": """These weights are ported from the original paper.""",
         },
     )
     DEFAULT = IMAGENET1K_V1


 class EfficientNet_B6_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
+        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
         url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
         transforms=partial(
             ImageClassification, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC
@@ -583,13 +602,15 @@ class EfficientNet_B6_Weights(WeightsEnum):
                 "acc@1": 84.008,
                 "acc@5": 96.916,
             },
+            "_docs": """These weights are ported from the original paper.""",
         },
     )
     DEFAULT = IMAGENET1K_V1


 class EfficientNet_B7_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
+        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
         url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
         transforms=partial(
             ImageClassification, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC
@@ -601,6 +622,7 @@ class EfficientNet_B7_Weights(WeightsEnum):
                 "acc@1": 84.122,
                 "acc@5": 96.908,
             },
+            "_docs": """These weights are ported from the original paper.""",
         },
     )
     DEFAULT = IMAGENET1K_V1
@@ -622,6 +644,11 @@ class EfficientNet_V2_S_Weights(WeightsEnum):
                 "acc@1": 84.228,
                 "acc@5": 96.878,
             },
+            "_docs": """
+                These weights improve upon the results of the original paper by using a modified version of TorchVision's
+                `new training recipe
+                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
+            """,
         },
     )
     DEFAULT = IMAGENET1K_V1
@@ -643,12 +670,18 @@ class EfficientNet_V2_M_Weights(WeightsEnum):
                 "acc@1": 85.112,
                 "acc@5": 97.156,
             },
+            "_docs": """
+                These weights improve upon the results of the original paper by using a modified version of TorchVision's
+                `new training recipe
+                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
+            """,
         },
     )
     DEFAULT = IMAGENET1K_V1


 class EfficientNet_V2_L_Weights(WeightsEnum):
+    # Weights ported from https://github.com/google/automl/tree/master/efficientnetv2
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth",
         transforms=partial(
@@ -666,6 +699,7 @@ class EfficientNet_V2_L_Weights(WeightsEnum):
                 "acc@1": 85.808,
                 "acc@5": 97.788,
             },
+            "_docs": """These weights are ported from the original paper.""",
         },
     )
     DEFAULT = IMAGENET1K_V1
@@ -1036,13 +1070,11 @@ def efficientnet_v2_l(

 model_urls = _ModelURLs(
     {
-        # Weights ported from https://github.com/rwightman/pytorch-image-models/
         "efficientnet_b0": EfficientNet_B0_Weights.IMAGENET1K_V1.url,
         "efficientnet_b1": EfficientNet_B1_Weights.IMAGENET1K_V1.url,
         "efficientnet_b2": EfficientNet_B2_Weights.IMAGENET1K_V1.url,
         "efficientnet_b3": EfficientNet_B3_Weights.IMAGENET1K_V1.url,
         "efficientnet_b4": EfficientNet_B4_Weights.IMAGENET1K_V1.url,
-        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
         "efficientnet_b5": EfficientNet_B5_Weights.IMAGENET1K_V1.url,
         "efficientnet_b6": EfficientNet_B6_Weights.IMAGENET1K_V1.url,
         "efficientnet_b7": EfficientNet_B7_Weights.IMAGENET1K_V1.url,
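Splitting the common meta into V1/V2 variants only changes which recipe link each weight documents; usage is unchanged. A typical end-to-end example, again assuming a v0.13-era torchvision with this API:

import torch
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

weights = EfficientNet_B0_Weights.DEFAULT  # IMAGENET1K_V1, ported from rwightman's repo
model = efficientnet_b0(weights=weights).eval()

preprocess = weights.transforms()  # bicubic resize to 256, center crop to 224, normalize
batch = preprocess(torch.rand(3, 500, 400)).unsqueeze(0)  # random stand-in for a real image
with torch.no_grad():
    scores = model(batch).softmax(dim=-1)
print(weights.meta["categories"][scores.argmax().item()])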
1 change: 1 addition & 0 deletions torchvision/models/googlenet.py
@@ -288,6 +288,7 @@ class GoogLeNet_Weights(WeightsEnum):
                 "acc@1": 69.778,
                 "acc@5": 89.530,
             },
+            "_docs": """These weights are ported from the original paper.""",
         },
     )
     DEFAULT = IMAGENET1K_V1
1 change: 1 addition & 0 deletions torchvision/models/inception.py
@@ -420,6 +420,7 @@ class Inception_V3_Weights(WeightsEnum):
                 "acc@1": 77.294,
                 "acc@5": 93.450,
             },
+            "_docs": """These weights are ported from the original paper.""",
         },
     )
     DEFAULT = IMAGENET1K_V1
10 changes: 10 additions & 0 deletions torchvision/models/mnasnet.py
@@ -229,6 +229,7 @@ class MNASNet0_5_Weights(WeightsEnum):
                 "acc@1": 67.734,
                 "acc@5": 87.490,
             },
+            "_docs": """These weights reproduce closely the results of the paper.""",
         },
     )
     DEFAULT = IMAGENET1K_V1
@@ -246,6 +247,10 @@ class MNASNet0_75_Weights(WeightsEnum):
                 "acc@1": 71.180,
                 "acc@5": 90.496,
             },
+            "_docs": """
+                These weights were trained from scratch by using TorchVision's `new training recipe
+                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
+            """,
         },
     )
     DEFAULT = IMAGENET1K_V1
@@ -262,6 +267,7 @@ class MNASNet1_0_Weights(WeightsEnum):
                 "acc@1": 73.456,
                 "acc@5": 91.510,
             },
+            "_docs": """These weights reproduce closely the results of the paper.""",
         },
     )
     DEFAULT = IMAGENET1K_V1
@@ -279,6 +285,10 @@ class MNASNet1_3_Weights(WeightsEnum):
                 "acc@1": 76.506,
                 "acc@5": 93.522,
             },
+            "_docs": """
+                These weights were trained from scratch by using TorchVision's `new training recipe
+                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
+            """,
         },
     )
     DEFAULT = IMAGENET1K_V1
6 changes: 6 additions & 0 deletions torchvision/models/mobilenetv2.py
@@ -212,6 +212,7 @@ class MobileNet_V2_Weights(WeightsEnum):
                 "acc@1": 71.878,
                 "acc@5": 90.286,
             },
+            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
         },
     )
     IMAGENET1K_V2 = Weights(
@@ -224,6 +225,11 @@ class MobileNet_V2_Weights(WeightsEnum):
                 "acc@1": 72.154,
                 "acc@5": 90.822,
             },
+            "_docs": """
+                These weights improve upon the results of the original paper by using a modified version of TorchVision's
+                `new training recipe
+                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
+            """,
         },
     )
     DEFAULT = IMAGENET1K_V2
9 changes: 9 additions & 0 deletions torchvision/models/mobilenetv3.py
@@ -321,6 +321,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum):
                 "acc@1": 74.042,
                 "acc@5": 91.340,
             },
+            "_docs": """These weights were trained from scratch by using a simple training recipe.""",
         },
     )
     IMAGENET1K_V2 = Weights(
@@ -334,6 +335,11 @@ class MobileNet_V3_Large_Weights(WeightsEnum):
                 "acc@1": 75.274,
                 "acc@5": 92.566,
             },
+            "_docs": """
+                These weights improve marginally upon the results of the original paper by using a modified version of
+                TorchVision's `new training recipe
+                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
+            """,
         },
     )
     DEFAULT = IMAGENET1K_V2
@@ -351,6 +357,9 @@ class MobileNet_V3_Small_Weights(WeightsEnum):
                 "acc@1": 67.668,
                 "acc@5": 87.402,
             },
+            "_docs": """
+                These weights improve upon the results of the original paper by using a simple training recipe.
+            """,
         },
     )
     DEFAULT = IMAGENET1K_V1
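The reworded sentence injected by conf.py is generated from ``field.transforms().describe()``, so the same text can be reproduced outside of Sphinx. A quick check (the exact wording of the output depends on the torchvision version):

from torchvision.models import MobileNet_V3_Large_Weights

weights = MobileNet_V3_Large_Weights.IMAGENET1K_V2
# describe() returns a human-readable summary of the inference preprocessing
# (resize, central crop, rescale, normalize) that the docs embed per weight.
print(weights.transforms().describe())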