From 0da2e769210be8f21145ed8ec735bc5e495e9611 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Fri, 1 Apr 2022 16:41:41 +0100 Subject: [PATCH 1/3] Add regnet_y_16gf and regnet_y_32gf from swag --- torchvision/models/regnet.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 72093686d84..429d74fd4c1 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -412,6 +412,15 @@ def _regnet( "interpolation": InterpolationMode.BILINEAR, } +_COMMON_SWAG_META = { + **_COMMON_META, + "publication_year": 2022, + "size": (384, 384), + "recipe": "https://github.com/facebookresearch/SWAG", + "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE", + "interpolation": InterpolationMode.BICUBIC, +} + class RegNet_Y_400MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( @@ -566,6 +575,17 @@ class RegNet_Y_16GF_Weights(WeightsEnum): "acc@5": 96.328, }, ) + IMAGENET1K_SWAG_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_16gf_swag-43afe44d.pth", + transforms=partial(ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC), + meta={ + **_COMMON_SWAG_META, + "num_params": 83590140, + # still mock + "acc@1": 86.02, + "acc@5": 98.05, + }, + ) DEFAULT = IMAGENET1K_V2 @@ -592,6 +612,17 @@ class RegNet_Y_32GF_Weights(WeightsEnum): "acc@5": 96.498, }, ) + IMAGENET1K_SWAG_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_32gf_swag-04fdfa75.pth", + transforms=partial(ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC), + meta={ + **_COMMON_SWAG_META, + "num_params": 145046770, + # still mock + "acc@1": 86.83, + "acc@5": 98.36, + }, + ) DEFAULT = IMAGENET1K_V2 From da7a18ff726e407af6d806d2cc36ff3eae9bba23 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Fri, 1 Apr 2022 16:46:01 +0100 Subject: [PATCH 2/3] Format with ufmt --- torchvision/models/regnet.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 429d74fd4c1..407a71ba525 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -577,7 +577,9 @@ class RegNet_Y_16GF_Weights(WeightsEnum): ) IMAGENET1K_SWAG_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_16gf_swag-43afe44d.pth", - transforms=partial(ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC), + transforms=partial( + ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC + ), meta={ **_COMMON_SWAG_META, "num_params": 83590140, @@ -614,7 +616,9 @@ class RegNet_Y_32GF_Weights(WeightsEnum): ) IMAGENET1K_SWAG_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_32gf_swag-04fdfa75.pth", - transforms=partial(ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC), + transforms=partial( + ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC + ), meta={ **_COMMON_SWAG_META, "num_params": 145046770, From 913be93b49c7897251033f34e0c29f68e171193e Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Fri, 1 Apr 2022 20:33:25 +0100 Subject: [PATCH 3/3] Add the experiment accuracy --- torchvision/models/regnet.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 407a71ba525..0748ee0460d 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -583,9 +583,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum): meta={ **_COMMON_SWAG_META, "num_params": 83590140, - # still mock - "acc@1": 86.02, - "acc@5": 98.05, + "acc@1": 86.012, + "acc@5": 98.054, }, ) DEFAULT = IMAGENET1K_V2 @@ -622,9 +621,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum): meta={ **_COMMON_SWAG_META, "num_params": 145046770, - # still mock - "acc@1": 86.83, - "acc@5": 98.36, + "acc@1": 86.838, + "acc@5": 98.362, }, ) DEFAULT = IMAGENET1K_V2