Add MobileNet V2 #818


Merged
fmassa merged 6 commits into pytorch:master from mobilenet-v2 on Mar 28, 2019

Conversation

@fmassa
Member

fmassa commented Mar 27, 2019

This PR adds support for MobileNetV2.

It is heavily based on the implementation in #625 by @tonylins.

I'm currently training the model from scratch with a custom script, and I'll be uploading the weights (together with the training hyperparameters) once training finishes and I match reported accuracies.

@codecov-io

Codecov Report

Merging #818 into master will increase coverage by 0.22%.
The diff coverage is 96.61%.


@@            Coverage Diff            @@
##           master    #818      +/-   ##
=========================================
+ Coverage   51.58%   51.8%   +0.22%     
=========================================
  Files          34      35       +1     
  Lines        3342    3401      +59     
  Branches      536     545       +9     
=========================================
+ Hits         1724    1762      +38     
- Misses       1486    1497      +11     
- Partials      132     142      +10
Impacted Files                         Coverage Δ
torchvision/models/__init__.py         100% <100%> (ø) ⬆️
torchvision/models/mobilenet.py        96.55% <96.55%> (ø)
torchvision/datasets/folder.py         65.88% <0%> (-2.36%) ⬇️
torchvision/datasets/cifar.py          31.68% <0%> (-1.99%) ⬇️
torchvision/datasets/lsun.py           17.64% <0%> (-1.97%) ⬇️
torchvision/datasets/voc.py            18.44% <0%> (-1.95%) ⬇️
torchvision/datasets/caltech.py        19.51% <0%> (-1.63%) ⬇️
torchvision/transforms/transforms.py   81.48% <0%> (-1.36%) ⬇️
torchvision/transforms/functional.py   69.52% <0%> (-0.96%) ⬇️
torchvision/datasets/imagenet.py       21.21% <0%> (ø) ⬆️
... and 1 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 27ff89f...ddaf5da.

@fmassa
Member Author

fmassa commented Mar 28, 2019

A first training run gave 70.8% top-1 accuracy, which is ~1% lower than the reported 71.8%.
I'll merge this PR as is to move forward, and will update the results with a pre-trained model (along with the accompanying training hyperparameters) in a follow-up PR.

@fmassa fmassa merged commit a61803f into pytorch:master Mar 28, 2019
@fmassa fmassa deleted the mobilenet-v2 branch March 28, 2019 14:11
@hsparrow

hsparrow commented Apr 1, 2019

Hi, I'm trying to understand the MobileNetV2 implementation, but I'm confused: when t (expand_ratio) equals 1, as in the first bottleneck, it seems we don't add any layer for it. Also, is x.mean([2, 3]) equivalent to the avgpool(7*7) in the forward pass?

@fmassa
Member Author

fmassa commented Apr 2, 2019

@hsparrow Hi,

when t (expand_ratio) equals 1, as in the first bottleneck, it seems we don't add any layer for it.

Note that range(1) gives one element

is x.mean([2, 3]) equivalent to the avgpool(7*7) in the forward pass

Yes, it is equivalent to `adaptive_avg_pool2d(1)`.
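For anyone checking that equivalence, here is a quick sanity test (a minimal sketch; the tensor shape is just an example):

import torch
import torch.nn.functional as F

x = torch.randn(2, 1280, 7, 7)               # NCHW feature map, e.g. MobileNetV2's last stage
a = x.mean([2, 3])                           # average over the spatial dims H and W
b = F.adaptive_avg_pool2d(x, 1).flatten(1)   # global average pool, then drop the 1x1 dims
assert torch.allclose(a, b)                  # both have shape (2, 1280)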

@hsparrow

hsparrow commented Apr 2, 2019

@fmassa Thanks for your reply!

Note that range(1) gives one element

But I'm still confused about this. In the InvertedResidual module, we only add that block to the layers when expand_ratio doesn't equal 1. So in the first bottleneck, where t=1, is self.conv supposed to be empty?

@fmassa
Member Author

fmassa commented Apr 2, 2019

@hsparrow this part of the code is not indented at the same level, so it is always added:

layers.extend([
    # dw
    ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
    # pw-linear
    nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
    nn.BatchNorm2d(oup),
])
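For context, the surrounding construction looks roughly like this (a sketch using the names from the snippet above; only the pointwise expansion is skipped when expand_ratio == 1, the depthwise and projection layers are always added):

layers = []
if expand_ratio != 1:
    # pw: 1x1 pointwise expansion, the only layer skipped when t == 1
    layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
layers.extend([
    # dw: depthwise conv, always present
    ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
    # pw-linear: linear projection, always present
    nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
    nn.BatchNorm2d(oup),
])
self.conv = nn.Sequential(*layers)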

@hsparrow

hsparrow commented Apr 2, 2019

@fmassa Got it! My bad!

@jeremyjordan

@fmassa do you have an estimate when you'll be able to submit the subsequent PR with the pre-trained ImageNet weights?

@fmassa
Member Author

fmassa commented Apr 22, 2019

@jeremyjordan I need to retrain the models, changing a few hyperparameters.
The model I trained some time ago has a top-1 accuracy of 70.3, while the expected accuracy is 71.8, so quite a bit lower.

My target for uploading pre-trained models that match the reported accuracies is the end of the month, so in ~1 week.

@D-X-Y

D-X-Y commented May 3, 2019

@fmassa Would you mind letting me know the hyperparameters you used to train MobileNetV2, such as the optimizer and weight decay?

@d-li14

d-li14 commented May 5, 2019

For reference, there are pre-trained MobileNetV2 models trained with a cosine LR decay strategy at https://github.com/d-li14/mobilenetv2.pytorch, reaching 72.2% accuracy with an identical network definition.
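For anyone wanting to try that schedule, cosine decay in PyTorch looks roughly like this (a sketch; the epoch count and initial lr here are illustrative assumptions, not d-li14's exact settings):

import torch

model = torch.nn.Linear(10, 10)  # stand-in for the real network
optimizer = torch.optim.SGD(model.parameters(), lr=0.05,
                            momentum=0.9, weight_decay=4e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=150)

for epoch in range(150):
    # train_one_epoch(model, optimizer, ...)  # training loop elided
    scheduler.step()  # anneal the lr along a cosine curve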

@D-X-Y

D-X-Y commented May 5, 2019

@d-li14 Thank you, I will try it.

@fmassa
Member Author

fmassa commented May 6, 2019

@D-X-Y once I get all the models giving the expected results, the hyperparameters for training the models will be available in https://github.com/pytorch/vision/tree/master/references/classification

@matthewygf
Contributor

@fmassa Hi, thank you for implementing MobileNetV2. Is allowing users to experiment with the inverted_residual_setting part of torchvision's future plans?

@fmassa
Member Author

fmassa commented May 28, 2019

@matthewygf if you send a PR letting the user specify the inverted_residual_setting, while keeping the current default for mobilenet_v2, I'd be happy to merge such a PR!
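For reference, the API eventually took roughly this shape (a sketch assuming the constructor torchvision later adopted; each row of the setting is [t, c, n, s] = expansion factor, output channels, repeat count, stride of the first block):

from torchvision.models import MobileNetV2

# The standard MobileNetV2 table with one example tweak: fewer repeats
# in the 64-channel stage (the default there is n=4).
custom_setting = [
    [1, 16, 1, 1],
    [6, 24, 2, 2],
    [6, 32, 3, 2],
    [6, 64, 2, 2],
    [6, 96, 3, 1],
    [6, 160, 3, 2],
    [6, 320, 1, 1],
]
model = MobileNetV2(num_classes=1000, inverted_residual_setting=custom_setting)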

@Mxbonn

Mxbonn commented Jun 19, 2019

@fmassa
in #1005 you mention

Yes, the model that I released uses that script.
Here are the arguments I used:

--model mobilenet_v2 --epochs 300 --lr 0.045 --wd 0.00004 --lr-step-size 1 --lr-gamma 0.98

Does that mean you only trained with a batch size of 32 and on 1 GPU?

@fmassa
Member Author

fmassa commented Jun 19, 2019

@Mxbonn sorry for not being clear, I trained on 8 GPUs, each GPU having a batch size of 32.

Arguments passed:  --model mobilenet_v2 --epochs 300 --lr 0.045 --wd 0.00004 --lr-step-size 1 --lr-gamma 0.98
0: | distributed init (rank 0):
2: | distributed init (rank 2):
6: | distributed init (rank 6):
4: | distributed init (rank 4):
1: | distributed init (rank 1):
5: | distributed init (rank 5):
7: | distributed init (rank 7):
3: | distributed init (rank 3):
0: Namespace(batch_size=32, data_path='/datasets01/imagenet_full_size/061417/', device='cuda', dist_backend='nccl', dist_url='tcp://learnfair1331:40000', distributed=True, epochs=300, gpu=0, lr=0.045, lr_gamma=0.98, lr_step_size=1, model='mobilenet_v2', momentum=0.9, output_dir='/checkpoint/fmassa/jobs/classification_logs/mobilenet_12644656', print_freq=10, rank=0, resume='', sync_bn=False, test_only=False, weight_decay=4e-05, workers=16, world_size=8)
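In PyTorch terms, those flags map roughly to the following optimizer and schedule (a sketch matching the arguments above, not the reference script itself):

import torch

model = torch.nn.Linear(10, 10)  # stand-in for mobilenet_v2
optimizer = torch.optim.SGD(model.parameters(), lr=0.045,
                            momentum=0.9, weight_decay=4e-5)
# --lr-step-size 1 --lr-gamma 0.98: multiply the lr by 0.98 after every epoch
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.98)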

@andravin

andravin commented Jul 1, 2019

Please note the reference implementation does not apply weight decay to the depthwise convolutions:

with slim.arg_scope([slim.separable_conv2d], weights_regularizer=None) as s:

https://github.com/tensorflow/models/blob/58a3de6c68639c3eac95f7c84089f320d32a85bb/research/slim/nets/mobilenet/mobilenet.py#L474
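The PyTorch analogue would be to put the depthwise conv weights in a parameter group with weight decay disabled, roughly like this (a hedged sketch; the torchvision reference script does not do this):

import torch
import torch.nn as nn

def split_decay_params(model):
    # A conv is depthwise when groups == in_channels (and > 1)
    decay, no_decay = [], []
    for m in model.modules():
        is_depthwise = (isinstance(m, nn.Conv2d)
                        and m.groups == m.in_channels and m.groups > 1)
        for p in m.parameters(recurse=False):
            (no_decay if is_depthwise else decay).append(p)
    return decay, no_decay

# decay, no_decay = split_decay_params(model)
# optimizer = torch.optim.SGD(
#     [{"params": decay, "weight_decay": 4e-5},
#      {"params": no_decay, "weight_decay": 0.0}],
#     lr=0.045, momentum=0.9)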

@andravin

andravin commented Jul 1, 2019

Also see this URL for the full list of options the reference implementation uses to achieve 72% accuracy:
https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet

Edit: While the training preprocessing is different, you are already scaling the validation images the same way the inception_v2 preprocessing does. The latter seemed most significant in my experiments.
https://github.com/tensorflow/models/blob/696b69a498b43f8e6a1ecb24bb82f7b9db87c570/research/slim/preprocessing/inception_preprocessing.py
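For comparison, the standard torchvision evaluation preprocessing (resize the short side to 256, then center-crop to 224) looks like this (a sketch of the usual recipe, not a quote of the reference script):

from torchvision import transforms

eval_transform = transforms.Compose([
    transforms.Resize(256),        # scale the shorter side to 256
    transforms.CenterCrop(224),    # take the central 224x224 crop
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])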

@fmassa
Member Author

fmassa commented Jul 2, 2019

@andravin Applying those changes could potentially improve the performance, indeed.

The current model that I uploaded yields 71.9 top-1 accuracy, which is close enough I believe, but keeps the training code / transforms the same as for the other models, so that we can more easily factor out improvements in performance that are orthogonal to the model itself.

@andravin

andravin commented Jul 2, 2019

You previously reported 70.8% accuracy. What did you change to get 71.9%?

@fmassa
Member Author

fmassa commented Jul 3, 2019

@andravin a few things:

  • number of epochs (from 180 to 300)
  • learning rate schedule (decreasing every epoch with an lr_gamma of 0.98, instead of every 2 epochs with an lr_gamma of 0.96)
  • smaller batch size (at some point I had tried to fill the whole GPU by increasing the batch size 10x, with the lr also increased 10x)
  • no sync batch norm during training (results with it were consistently worse)

Also, here is the PR that uploaded the new model #917

@andravin

andravin commented Jul 3, 2019

Thanks @fmassa. That leads me to the next question: why such a small batch size (32 x 8 GPUs)? The MobileNetV2 paper uses 96 x 16.

BTW, I ran your script using:

python3 -m torch.distributed.launch --nproc_per_node=8\
        train.py --model mobilenet_v2\
        --epochs 300 --lr 0.045 --wd 0.00004 --lr-step-size 1 --lr-gamma 0.98\
        --batch-size=32

I had to modify train.py in order to add support for the --local_rank flag used by torch.distributed.launch.
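That modification amounts to accepting the extra flag, roughly (a sketch; the actual change to train.py may differ):

import argparse

parser = argparse.ArgumentParser()
# torch.distributed.launch (without --use_env) appends --local_rank=N to
# each worker's argv, so the script has to accept it
parser.add_argument("--local_rank", type=int, default=0)
args, _ = parser.parse_known_args()
# torch.cuda.set_device(args.local_rank)  # pin this process to its GPU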

Using an AWS EC2 p3.16xlarge instance (8 x V100), the projected running time is 137.7 hours. Is that what you expect?

Torch master checked out yesterday, CUDA 10.1, cuDNN 7.6.

@fmassa
Member Author

fmassa commented Jul 4, 2019

@andravin

I had to modify train.py in order to add support for the --local_rank flag used by torch.distributed.launch.

I use the flag --use_env in torch.distributed.launch, so that --local_rank is not needed in the main script.
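With --use_env, the launcher exports the rank as an environment variable instead, so the script can read it directly (a minimal sketch):

import os

local_rank = int(os.environ["LOCAL_RANK"])  # set by torch.distributed.launch --use_env
# torch.cuda.set_device(local_rank)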

Using an AWS EC2 p3.16xlarge instance (8 x V100), the projected running time is 137.7 hours. Is that what you expect?

The total time per epoch in my case (using 8 V100s with CUDA 10, not sure which cuDNN anymore) was 9 min 15 s on average, and the total training time (including periodic evaluation) was 2 days, 0:33:33, so your projected runtime is quite different from what I had.

that leads me to the next question, why such a small batch size (32 x 8 GPU)? The MobileNetV2 paper uses 96 x 16.

No good reason, other than that I didn't want to use 16 GPUs with async updates (as they mention in the paper).
I tried a few different configurations with different hyperparameters and got worse results, but I didn't do an extensive enough evaluation to justify that this schedule is better than theirs.

@andravin

andravin commented Jul 4, 2019

Thanks for the tip about --use_env @fmassa.

With --batch-size=192 I get an epoch time around 8 minutes 10 seconds, but with --batch-size=32 (your setting?) it slows down about 3x.

Any advice on how to track down the large performance difference we are seeing?

>>> import torch
>>> torch.cuda.nccl.version()
2406
>>> torch.version.cuda
'10.1.168'
>>> torch.backends.cudnn.version()
7601
>>> torch.__version__
'1.2.0a0+ffa15d2'

@fmassa
Member Author

fmassa commented Jul 4, 2019

@andravin wow, I just kicked off the same training again with a newer version of PyTorch, and training times are indeed 3x slower.

After some digging, the issue was the same as pytorch/pytorch#20311

Here are some results, on an 80-core machine:

OMP_NUM_THREADS   ETA
unset             0:36:56
10                0:19:34
6                 0:14:56

So I'd say the problem is on the PyTorch side, and we should just tune OMP_NUM_THREADS to a reasonably small value.

I'm using

>>> print(torch.__version__)
1.2.0.dev20190624
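For completeness, a similar cap can be applied from inside Python (a sketch; setting OMP_NUM_THREADS in the environment before launch, as above, achieves roughly the same thing for the OpenMP backend):

import torch

# Limit intra-op CPU parallelism so the 8 worker processes sharing the
# machine don't oversubscribe the cores; mirrors OMP_NUM_THREADS=6 above.
torch.set_num_threads(6)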

@andravin

andravin commented Jul 4, 2019

That's great @fmassa, with OMP_NUM_THREADS=6 my --batch-size=192 run now gets 5:48 epochs.

Is your 14:56 (batch) ETA actually worse than the 9 min 15 second epochs you reported earlier, or do your recent measurements include a warmup penalty?

@fmassa
Member Author

fmassa commented Jul 4, 2019

@andravin I haven't tried other OMP_NUM_THREADS values; there might be a sweet spot where the numbers match what I had before, but for the values I tried it does indeed seem worse.

@andravin

andravin commented Jul 5, 2019

I get the best performance with OMP_NUM_THREADS=1. For --batch-size=32 the epoch time is 6:35.

With OMP_NUM_THREADS=6 it is 1.5x slower.

All batch sizes benefit from OMP_NUM_THREADS=1, but small batch sizes benefit most.

Note: my PyTorch build was NO_MKLDNN=1 python setup.py install, although I don't expect that to matter.

@fmassa
Member Author

fmassa commented Jul 5, 2019

@andravin very useful information, thanks for sharing!

@andravin

andravin commented Jul 8, 2019

@fmassa I was not able to reproduce your training result. The best test accuracy was 71.536% after 290 epochs.

| distributed init (rank 6): env://
| distributed init (rank 7): env://
| distributed init (rank 0): env://
| distributed init (rank 2): env://
| distributed init (rank 5): env://
| distributed init (rank 1): env://
| distributed init (rank 3): env://
| distributed init (rank 4): env://
Namespace(apex=False, apex_opt_level='O1', batch_size=32, cache_dataset=False, data_path='/data/imagenet', device='cuda', dist_backend='nccl', dist_url='env://', distributed=True, epochs=300, gpu=0, lr=0.045, lr_gamma=0.98, lr_step_size=1, model='mobilenet_v2', momentum=0.9, output_dir='/data/trained-models/mobilenet_v2-pytorch-defaults', pretrained=False, print_freq=100, rank=0, resume='', start_epoch=0, sync_bn=False, test_only=False, weight_decay=4e-05, workers=16, world_size=8)

Testing the pretrained model, I get 71.878%.

@fmassa
Member Author

fmassa commented Jul 8, 2019

@andravin I could try kicking off some more training runs to verify whether this is just random noise in model training, or if something changed since I last trained it.

It will be some time before I can test it, though.

@andravin

andravin commented Jul 8, 2019

@fmassa It might have been a mistake for me to use a development version of pytorch for this experiment. I doubt the large accuracy difference can be explained by random variance alone.

Edit: Re-running now using pytorch 1.1 (stable), CUDA 10.0, cuDNN 7.5, torchvision 0.3.

Please advise if this configuration is most likely to reproduce your results.

The single-epoch training time is still, remarkably, 6:34 using OMP_NUM_THREADS=1, unchanged from my earlier measurements with recent software. Beware of first-epoch warmup time when benchmarking: my first epoch was 14:50.

@fmassa
Member Author

fmassa commented Jul 9, 2019

@andravin this is indeed the configuration that I used; my log dates from May 8th. If you still can't reproduce the results, please do let me know.

@andravin

andravin commented Jul 10, 2019

@fmassa using the aforementioned stable build, I get 71.830% accuracy after the last epoch, and 71.888% after epoch 288.

By the way, did you report last epoch accuracy or best epoch accuracy?

Anyway, success! Training time 1 day, 10:13:30.

@fmassa
Member Author

fmassa commented Jul 10, 2019

@andravin I reported the best accuracy, which for me was obtained at epoch 285. So our results match fairly closely, which is great!

And it's good to know that there might be some regression between PyTorch 1.1 / torchvision 0.3 and now. What remains to be seen is whether it's in PyTorch, torchvision, or both :-)

@andravin

@fmassa Would this have broken weight initialization: pytorch/pytorch#22529 ?

@fmassa
Member Author

fmassa commented Jul 10, 2019

@andravin it's a good hypothesis that you were using a version with the bugged CUDA RNG, but to my understanding the model initialization is done on the CPU first, and the model is then moved to the GPU, so we actually exercise the CPU codepath of the RNG.
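In other words, the usual flow is (a sketch, not the exact training script):

import torch
import torchvision

torch.manual_seed(0)
model = torchvision.models.mobilenet_v2()  # weights are initialized here, on the CPU
model = model.to("cuda")                   # only the finished tensors move to the GPU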

@andravin

andravin commented Jul 10, 2019

More mobilenet_v2 training speed tests for torch 1.1, torchvision 0.3, CUDA 10.0, cuDNN 7.5, libjpeg-turbo8:

Using pillow-simd:

OMP_NUM_THREADS   workers   Epoch training time
2                 8         6:42
1                 8         6:17
1                 16        6:16

Using regular pillow:

OMP_NUM_THREADS   workers   Epoch training time
6                 8         9:56
6                 16        9:50
4                 16        7:01
1                 16        6:34
1                 8         6:32

Benchmarking gotchas: 1) ignore first epoch time due to substantial warmup cost, 2) make sure to kill all python processes after interrupting with Ctrl-c.

python3 -m torch.distributed.launch --use_env --nproc_per_node=8\
        train.py --model $MODEL\
        --epochs 300\
        --lr 0.045 --wd 0.00004 --lr-step-size 1 --lr-gamma 0.98\
        --data-path=/data/imagenet\
        --output-dir=/data/trained-models/$MODEL\
        --batch-size=32\
        --print-freq=100\
        --workers=$WORKERS |& tee -a /data/log/$MODEL/train-$MODEL.log

Again, the machine is an AWS EC2 p3.16xlarge instance, 8 x V100, 480 GB RAM.

$ lscpu
Thread(s) per core:    2
Core(s) per socket:    16
Socket(s):             2
NUMA node(s):          2
Vendor ID:             GenuineIntel
Model name:            Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz
L1d cache:             32K
L1i cache:             32K
L2 cache:              256K
L3 cache:              46080K

@andravin

andravin commented Jul 15, 2019

@fmassa I ran into a problem when adding the --apex flag, so I opened #1119. Looks like an easy fix.

But mobilenet_v2 training is slower with mixed precision than without:

OMP_NUM_THREADS=1, workers=16

batch size   --apex   Epoch training time
32           Yes      7:52
32           No       6:16
128          Yes      5:19
128          No       5:11
256          Yes      5:08
192          No       5:09
256          No       out of memory

@dakshjotwani
Contributor

But mobilenet_v2 training is slower with mixed precision than without

@andravin See NVIDIA/apex#76 for a discussion on this topic. Mixed precision training can improve training times if your GPU has tensor cores.

Mixed precision training, however, can reduce your memory usage (your results on a batch size of 256 show this).

This article covers this topic quite well.
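For reference, the apex setup being benchmarked looks roughly like this (a sketch of the amp API; the surrounding training loop is elided):

import torch
from apex import amp

model = torch.nn.Linear(10, 10).cuda()  # stand-in for mobilenet_v2
optimizer = torch.optim.SGD(model.parameters(), lr=0.045, momentum=0.9)
# O1 patches common ops to run in fp16 while keeping fp32 master weights
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# Inside the training loop, the loss is scaled before backward:
# with amp.scale_loss(loss, optimizer) as scaled_loss:
#     scaled_loss.backward()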

@andravin

@dakshjotwani V100 GPUs, as documented above, have tensor cores.

@dakshjotwani
Contributor

@andravin my bad, I wasn't sure which GPUs were being used. I remembered facing a similar issue before and thought it might be relevant here.

@andravin

@dakshjotwani Understandable, because we would not expect a tensor-core-enabled GPU to run slower with mixed precision.

The throughput per GPU for batch size 256 is actually less than 1 TFLOP/s, which is less than 1% utilization of the tensor cores.

I get a 2.3x speedup by eliminating the data loader, but that still yields less than 2% utilization.
