add swin_s and swin_b variants and improved swin_t #6048


Merged: 10 commits merged into pytorch:main on May 19, 2022

Conversation

@jdsgomes (Contributor) commented on May 18, 2022

Adding improved weights for swin_t and new weights for swin_s and swin_b.

The new variants were trained with TrivialAugment instead of RandAugment, since TrivialAugment is actually closer to the original implementation.
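
For reference, a minimal sketch of loading the new weights through torchvision's multi-weight API, assuming the Swin_S_Weights enum this PR introduces and its IMAGENET1K_V1 entry:

# Hypothetical usage sketch, not part of this PR's diff.
import torch
from torchvision.models import swin_s, Swin_S_Weights

weights = Swin_S_Weights.IMAGENET1K_V1
model = swin_s(weights=weights).eval()

# Each weights entry bundles its own inference transforms
# (resize, center crop, normalization).
preprocess = weights.transforms()
batch = preprocess(torch.rand(3, 256, 256)).unsqueeze(0)
with torch.inference_mode():
    logits = model(batch)
print(logits.shape)  # torch.Size([1, 1000])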

Training commands:

# swin_t
python -u run_with_submitit.py --timeout 3000 --ngpus 8 --nodes 1  --model swin_t --epochs 300 \
    --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0  \
    --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr \
    --lr-min 0.00001 --lr-warmup-method linear  --lr-warmup-epochs 20 --lr-warmup-decay 0.01 \
    --amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 \
    --random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler \
    --ra-reps 4  --val-resize-size 224 

# swin_s
python -u run_with_submitit.py --timeout 3000 --ngpus 8 --nodes 1  --model swin_s --epochs 300 \
    --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0  \
    --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr \
    --lr-min 0.00001 --lr-warmup-method linear  --lr-warmup-epochs 20 --lr-warmup-decay 0.01 \
    --amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 \
    --random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler \
    --ra-reps 4  --val-resize-size 224 

# swin_b
python -u run_with_submitit.py --timeout 3000 --ngpus 8 --nodes 1  --model swin_b --epochs 300 \
    --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0 \
    --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr \
    --lr-min 0.00001 --lr-warmup-method linear  --lr-warmup-epochs 20 --lr-warmup-decay 0.01 \
    --amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 \
    --random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler \
    --ra-reps 4  --val-resize-size 224 
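
The --auto-augment ta_wide flag corresponds to torchvision's TrivialAugmentWide transform. Below is a rough sketch of the per-image training transforms the flags above imply (mixup/cutmix are applied per batch in the reference scripts and are omitted here); treat it as an approximation of the reference presets, not the exact code:

import torch
from torchvision import transforms
from torchvision.transforms import InterpolationMode

# Assumes PIL images as input, as in the classification reference scripts.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, interpolation=InterpolationMode.BICUBIC),
    transforms.RandomHorizontalFlip(),
    # --auto-augment ta_wide
    transforms.TrivialAugmentWide(interpolation=InterpolationMode.BICUBIC),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    # --random-erase 0.25
    transforms.RandomErasing(p=0.25),
])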

Test commands:

# swin_t
srun -p dev --cpus-per-task=96 -t 24:00:00 --gpus-per-node=1 torchrun --nproc_per_node=1 train.py \
    --model swin_t --test-only --resume $EXPERIMENTS_PATH/43599/model_299.pth --interpolation bicubic   \
    --val-resize-size 232
# Test:  Acc@1 81.474 Acc@5 95.776

# swin_s
srun -p dev --cpus-per-task=96 -t 24:00:00 --gpus-per-node=1 torchrun --nproc_per_node=1 train.py \
    --model swin_s --test-only --resume $EXPERIMENTS_PATH/43602/model_299.pth --interpolation bicubic   \
    --val-resize-size 246
# Test:  Acc@1 83.196 Acc@5 96.360

# swin_b
srun -p dev --cpus-per-task=96 -t 24:00:00 --gpus-per-node=1 torchrun --nproc_per_node=1 train.py \
    --model swin_b --test-only --resume $EXPERIMENTS_PATH/44985/model_299.pth --interpolation bicubic   \
    --val-resize-size 238
# Test:  Acc@1 83.582 Acc@5 96.640
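
The per-variant --val-resize-size values above (232, 246, 238) resize the shorter side before the 224 center crop. A sketch of the equivalent eval preprocessing, mirroring torchvision's classification presets:

from torchvision import transforms
from torchvision.transforms import InterpolationMode

def val_transform(resize_size, crop_size=224):
    return transforms.Compose([
        transforms.Resize(resize_size, interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])

swin_t_eval = val_transform(232)
swin_s_eval = val_transform(246)
swin_b_eval = val_transform(238)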

cc @xiaohu2015 @datumbox

@jdsgomes marked this pull request as draft on May 18, 2022 at 15:18
@@ -226,7 +228,9 @@ convnext_tiny 82.520 96.146
 convnext_small 83.616 96.650
 convnext_base 84.062 96.870
 convnext_large 84.414 96.976
-swin_t 81.358 95.526
+swin_t 81.474 95.776

@jdsgomes (Contributor Author): original repo Acc@1 81.2 Acc@5 95.5

@@ -226,7 +228,9 @@ convnext_tiny 82.520 96.146
 convnext_small 83.616 96.650
 convnext_base 84.062 96.870
 convnext_large 84.414 96.976
-swin_t 81.358 95.526
+swin_t 81.474 95.776
+swin_s 83.196 96.360

@jdsgomes (Contributor Author): original repo Acc@1 83.2 Acc@5 96.2

-swin_t 81.358 95.526
+swin_t 81.474 95.776
+swin_s 83.196 96.360
+swin_b 83.582 96.640

@jdsgomes (Contributor Author): original repo Acc@1 83.5 Acc@5 96.5

@jdsgomes changed the title from "[WIP] add swin_s and swin_b variants" to "add swin_s and swin_b variants and improved swin_t" on May 18, 2022
@jdsgomes marked this pull request as ready for review on May 18, 2022 at 16:29
@datumbox (Contributor) commented on May 18, 2022

@jdsgomes Can you provide proof of the validation stats for future reference?

Here is a recent PR for what I mean: #6019 (PR description)

@jdsgomes marked this pull request as draft on May 18, 2022 at 17:00
@xiaohu2015 (Contributor) commented:

@jdsgomes Good. It seems that you added some other tricks: ta_wide, model-ema, ra-sampler?

@jdsgomes marked this pull request as ready for review on May 19, 2022 at 08:01
@datumbox (Contributor) commented:

@jdsgomes I've pushed a tiny change to the docstrings to indicate it's a "similar" recipe, not the same one. As @xiaohu2015 highlighted, we use a mix of the paper recipe with the standard recipe of TorchVision, so this will be more factual.

My intention was to merge this prior to merging the PR that removes the old models.rst (to avoid conflicts), but it seems that your PR fails on a specific test (see this), which looks related. Could you please check it out?

@jdsgomes (Contributor Author) commented:

Thanks @datumbox for the change to the docstrings. And @xiaohu2015, you are right, this is a slightly modified recipe; I should have highlighted that initially!

Looking into the error now.

@datumbox (Contributor) left a review:

LGTM, thanks @jdsgomes. Feel free to merge on Green-ish CI.

@datumbox mentioned this pull request on May 19, 2022
@jdsgomes merged commit 9d9cfab into pytorch:main on May 19, 2022
@xiaohu2015 (Contributor) commented:

The SwinV2 official code has been released; maybe we can also add swin_v2.

@datumbox (Contributor) commented:

@xiaohu2015 If you are up for it, by all means. I think it's possible to update this class to produce both Swin V1 and V2 models, similar to what we did with EfficientNets.

facebook-github-bot pushed a commit that referenced this pull request Jun 1, 2022
Summary:
* add swin_s and swin_b variants

* fix swin_b params

* fix n parameters and acc numbers

* adding missing acc numbers

* apply ufmt

* Updating `_docs` to reflect training recipe

* Fix exted for swin_b

Reviewed By: NicolasHug

Differential Revision: D36760946

fbshipit-source-id: 52b5056ee8b3efde4e00aebf7a7f732819d90c4f

Co-authored-by: Vasilis Vryniotis <[email protected]>