diff --git a/docs/source/models/swin_transformer.rst b/docs/source/models/swin_transformer.rst index b8726d71d2a..2f67b0d5274 100644 --- a/docs/source/models/swin_transformer.rst +++ b/docs/source/models/swin_transformer.rst @@ -23,3 +23,5 @@ more details about this class. :template: function.rst swin_t + swin_s + swin_b diff --git a/references/classification/README.md b/references/classification/README.md index 9eb95fd00e9..da30159542b 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -228,14 +228,14 @@ and `--batch_size 64`. ### SwinTransformer ``` torchrun --nproc_per_node=8 train.py\ ---model swin_t --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0\ ---bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear\ ---lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8\ ---clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ra +--model $MODEL --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0 --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear --lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler --ra-reps 4 --val-resize-size 224 ``` +Here `$MODEL` is one of `swin_t`, `swin_s` or `swin_b`. Note that `--val-resize-size` was optimized in a post-training step, see their `Weights` entry for the exact value. 
+ + ### ShuffleNet V2 ``` torchrun --nproc_per_node=8 train.py \ diff --git a/test/expect/ModelTester.test_swin_b_expect.pkl b/test/expect/ModelTester.test_swin_b_expect.pkl new file mode 100644 index 00000000000..d807ca3ed15 Binary files /dev/null and b/test/expect/ModelTester.test_swin_b_expect.pkl differ diff --git a/test/expect/ModelTester.test_swin_s_expect.pkl b/test/expect/ModelTester.test_swin_s_expect.pkl new file mode 100644 index 00000000000..2624dad4178 Binary files /dev/null and b/test/expect/ModelTester.test_swin_s_expect.pkl differ diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 69a0d5fd2fd..6e001c1d2dd 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -18,7 +18,11 @@ __all__ = [ "SwinTransformer", "Swin_T_Weights", + "Swin_S_Weights", + "Swin_B_Weights", "swin_t", + "swin_s", + "swin_b", ] @@ -408,9 +412,9 @@ def _swin_transformer( class Swin_T_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/swin_t-81486767.pth", + url="https://download.pytorch.org/models/swin_t-4c37bd06.pth", transforms=partial( - ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META, @@ -419,11 +423,57 @@ class Swin_T_Weights(WeightsEnum): "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", "_metrics": { "ImageNet-1K": { - "acc@1": 81.358, - "acc@5": 95.526, + "acc@1": 81.474, + "acc@5": 95.776, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class Swin_S_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_s-30134662.pth", + transforms=partial( + ImageClassification, crop_size=224, 
resize_size=246, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META, + "num_params": 49606258, + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", + "_metrics": { + "ImageNet-1K": { + "acc@1": 83.196, + "acc@5": 96.360, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class Swin_B_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_b-1f1feb5c.pth", + transforms=partial( + ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META, + "num_params": 87768224, + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", + "_metrics": { + "ImageNet-1K": { + "acc@1": 83.582, + "acc@5": 96.640, } }, - "_docs": """These weights reproduce closely the results of the paper using its training recipe.""", + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", }, ) DEFAULT = IMAGENET1K_V1 @@ -463,3 +513,75 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, * progress=progress, **kwargs, ) + + +def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_small architecture from + `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`_. + + Args: + weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_S_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. 
+ **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_ + for more details about this class. + + .. autoclass:: torchvision.models.Swin_S_Weights + :members: + """ + weights = Swin_S_Weights.verify(weights) + + return _swin_transformer( + patch_size=4, + embed_dim=96, + depths=[2, 2, 18, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + stochastic_depth_prob=0.3, + weights=weights, + progress=progress, + **kwargs, + ) + + +def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_base architecture from + `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`_. + + Args: + weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_B_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_ + for more details about this class. + + .. autoclass:: torchvision.models.Swin_B_Weights + :members: + """ + weights = Swin_B_Weights.verify(weights) + + return _swin_transformer( + patch_size=4, + embed_dim=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=7, + stochastic_depth_prob=0.5, + weights=weights, + progress=progress, + **kwargs, + )