
add swin_s and swin_b variants and improved swin_t #6048


Merged (10 commits, May 19, 2022)
docs/source/models/swin_transformer.rst (2 changes: 2 additions & 0 deletions)
@@ -23,3 +23,5 @@ more details about this class.
     :template: function.rst

     swin_t
+    swin_s
+    swin_b
references/classification/README.md (8 changes: 4 additions & 4 deletions)
@@ -228,14 +228,14 @@ and `--batch_size 64`.
 ### SwinTransformer
 ```
 torchrun --nproc_per_node=8 train.py\
---model swin_t --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0\
---bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear\
---lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8\
---clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ra
+--model $MODEL --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0 --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear --lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler --ra-reps 4 --val-resize-size 224
 ```
+Here `$MODEL` is one of `swin_t`, `swin_s` or `swin_b`.
+Note that `--val-resize-size` was optimized in a post-training step, see their `Weights` entry for the exact value.
+



 ### ShuffleNet V2
 ```
 torchrun --nproc_per_node=8 train.py \
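The README note above says the exact `--val-resize-size` for each variant is recorded in its `Weights` entry. A minimal sketch of reading it back from the enums added by this PR (not part of the diff; it assumes a torchvision build that includes these changes, and that the preset is stored as a `functools.partial`, as it is in the diff below):

```python
from torchvision.models import Swin_S_Weights

weights = Swin_S_Weights.IMAGENET1K_V1

# Evaluation preset bundled with the checkpoint: bicubic resize, 224x224 center crop, normalization.
preprocess = weights.transforms()

# The `transforms` field is a functools.partial over ImageClassification, so the
# post-training-optimized resize value is available as one of its keywords.
print(weights.transforms.keywords["resize_size"])        # 246 for swin_s
print(weights.meta["_metrics"]["ImageNet-1K"]["acc@1"])  # 83.196, as reported in this PR
```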
Binary file added test/expect/ModelTester.test_swin_b_expect.pkl
Binary file added test/expect/ModelTester.test_swin_s_expect.pkl
torchvision/models/swin_transformer.py (132 changes: 127 additions & 5 deletions)
@@ -18,7 +18,11 @@
 __all__ = [
     "SwinTransformer",
     "Swin_T_Weights",
+    "Swin_S_Weights",
+    "Swin_B_Weights",
     "swin_t",
+    "swin_s",
+    "swin_b",
 ]


@@ -408,9 +412,9 @@ def _swin_transformer(

 class Swin_T_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
-        url="https://download.pytorch.org/models/swin_t-81486767.pth",
+        url="https://download.pytorch.org/models/swin_t-4c37bd06.pth",
         transforms=partial(
-            ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
+            ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC
         ),
         meta={
             **_COMMON_META,
@@ -419,11 +423,57 @@ class Swin_T_Weights(WeightsEnum):
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
             "_metrics": {
                 "ImageNet-1K": {
-                    "acc@1": 81.358,
-                    "acc@5": 95.526,
+                    "acc@1": 81.474,
+                    "acc@5": 95.776,
                 }
             },
+            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class Swin_S_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/swin_s-30134662.pth",
+        transforms=partial(
+            ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_META,
+            "num_params": 49606258,
+            "min_size": (224, 224),
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 83.196,
+                    "acc@5": 96.360,
+                }
+            },
+            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class Swin_B_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/swin_b-1f1feb5c.pth",
+        transforms=partial(
+            ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_META,
+            "num_params": 87768224,
+            "min_size": (224, 224),
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 83.582,
+                    "acc@5": 96.640,
+                }
+            },
-            "_docs": """These weights reproduce closely the results of the paper using its training recipe.""",
+            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
         },
     )
     DEFAULT = IMAGENET1K_V1
Expand Down Expand Up @@ -463,3 +513,75 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, *
progress=progress,
**kwargs,
)


def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
"""
Constructs a swin_small architecture from
`Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/pdf/2103.14030>`_.

Args:
weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.Swin_S_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
for more details about this class.

.. autoclass:: torchvision.models.Swin_S_Weights
:members:
"""
weights = Swin_S_Weights.verify(weights)

return _swin_transformer(
patch_size=4,
embed_dim=96,
depths=[2, 2, 18, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
stochastic_depth_prob=0.3,
weights=weights,
progress=progress,
**kwargs,
)


def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
"""
Constructs a swin_base architecture from
`Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/pdf/2103.14030>`_.

Args:
weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.Swin_B_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
for more details about this class.

.. autoclass:: torchvision.models.Swin_B_Weights
:members:
"""
weights = Swin_B_Weights.verify(weights)

return _swin_transformer(
patch_size=4,
embed_dim=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
window_size=7,
stochastic_depth_prob=0.5,
weights=weights,
progress=progress,
**kwargs,
)
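
For completeness, a short usage sketch of the new builders and weights (not part of the diff itself; it assumes a torchvision build that includes this PR):

```python
import torch
from torchvision.models import swin_b, swin_s, Swin_B_Weights, Swin_S_Weights

# Pretrained small variant; weights=None (the default) gives a randomly initialized model instead.
model = swin_s(weights=Swin_S_Weights.IMAGENET1K_V1)
model.eval()

# The new base variant follows the same pattern.
model_b = swin_b(weights=Swin_B_Weights.IMAGENET1K_V1)
model_b.eval()

# Dummy batch just to check the output shape; real images should go through the
# weights' bundled transforms() preset first (see the note on --val-resize-size above).
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 1000]) for the ImageNet-1K weights
```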