[DRAFT] Mobile ViT #6965
Conversation
@yassineAlouini Thanks a lot for the PR! I've added a few nits and questions below. Let me know your thoughts.
For next steps, I propose to:
- Verify that the implementation produces the same accuracies with the ported weights from the paper. This should be done with our existing reference scripts, so we need the command you ran and its output showcasing the accuracy.
- Make any amendments to the reference scripts that will allow you to reproduce the training setup from the paper, and provide the command you would normally run to train it. We can then check out your PR and run it on our internal infra.

If performing step 1 is also challenging for you due to lack of infra, we can do the same as in step 2 and run it on our side. Just make sure your PR contains all the changes required so that we can run it on our infra by submitting your provided command. In that case, you would need to provide a reference to the weights you use from the original repo and the translation/porting script between the two implementations.
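For illustration, a minimal sketch of what such a translation/porting script could look like is shown below, assuming a simple prefix-based rename between the two sets of parameter names. The key prefixes and file names are hypothetical placeholders, not the actual layer names of either implementation.

```python
import torch

# Hypothetical prefix mapping from the original repository's parameter names
# to the names used by the torchvision implementation in this PR. The real
# mapping has to be derived by comparing the two state_dicts.
KEY_MAP = {
    "conv_1.block.conv.": "features.0.0.",
    "conv_1.block.norm.": "features.0.1.",
    # ... one entry per differently named block ...
}


def port_state_dict(original_sd: dict) -> dict:
    """Rename checkpoint keys so the weights load into the torchvision model."""
    ported = {}
    for key, value in original_sd.items():
        new_key = key
        for old_prefix, new_prefix in KEY_MAP.items():
            if key.startswith(old_prefix):
                new_key = new_prefix + key[len(old_prefix):]
                break
        ported[new_key] = value
    return ported


if __name__ == "__main__":
    # File names are placeholders.
    original_sd = torch.load("mobilevit_s_original.pt", map_location="cpu")
    torch.save(port_state_dict(original_sd), "mobilevit_s_torchvision.pt")
```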
torchvision/models/mobilevit.py
Outdated
# TODO: What about multi-scale sampler? Check later for V2 training...


class MobileViT_Weights(WeightsEnum):
Minor nit: Could you move the weights definition and their common vars after the model declaration? This is the structure we currently have for all other models. Ideally, in the future, we should find a way to enforce this order via the linter or stop enforcing it.
Yes, will move of course. 👌
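For reference, a minimal sketch of that ordering follows, with placeholder bodies rather than the actual contents of this PR (the builder name and metadata are assumptions):

```python
from torch import nn
from torchvision.models._api import WeightsEnum


# 1. Building blocks and the model class come first.
class MobileViT(nn.Module):
    ...


# 2. Common metadata shared by the weight entries.
_COMMON_META = {"categories": None}  # placeholder


# 3. The weights enum is declared after the model.
class MobileViT_Weights(WeightsEnum):
    ...


# 4. Finally, the model builder function(s).
def mobilevit(*, weights=None, progress=True, **kwargs):
    ...
```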
torchvision/models/mobilevit.py
Outdated
# TODO: Update with the correct values. For now, these are the expected ones from the paper.
"ImageNet-1K": {
    "acc@1": 78.4,
    "acc@5": 94.1,
I agree that we should retrain from scratch. As a first step, have you confirmed that the ported weights yield the same accuracy using TorchVision's reference scripts?
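For illustration, a rough sketch of such a check is below. In practice, the reference script (`references/classification/train.py` with `--test-only`) should be used so the preprocessing matches exactly; the builder name, checkpoint path, dataset location and resize/crop values here are all assumptions.

```python
import torch
import torchvision
from torchvision import datasets, transforms

# Assumed builder name and checkpoint path; adapt to this PR and your setup.
model = torchvision.models.get_model("mobilevit_s", weights=None)
model.load_state_dict(torch.load("mobilevit_s_torchvision.pt", map_location="cpu"))
model.eval().cuda()

# Preprocessing values are placeholders; use the reference script's transforms
# for an exact comparison with the paper.
preprocess = transforms.Compose(
    [transforms.Resize(288), transforms.CenterCrop(256), transforms.ToTensor()]
)
val = datasets.ImageFolder("/path/to/imagenet/val", transform=preprocess)
loader = torch.utils.data.DataLoader(val, batch_size=64, num_workers=8)

correct = total = 0
with torch.inference_mode():
    for images, targets in loader:
        preds = model(images.cuda()).argmax(dim=1).cpu()
        correct += (preds == targets).sum().item()
        total += targets.numel()
print(f"acc@1: {100.0 * correct / total:.2f}")
```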
self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
self.dropout = nn.Dropout(dropout)

# MLP block (inspired from swin_transformer.py)
@yassineAlouini If I understand correctly, the only changes you need are the `activation_layer`, the `inplace` value, and the lack of `ln_2`. What are your thoughts on reusing the block from `vision_transformer.py`?
@YosuaMichael, @jdsgomes any thoughts on whether we want to inherit/extend from the above? This building block is tricky because it's not as standard as other common blocks. Moreover, so far we have tried to avoid importing directly from other model architectures. What is your advice here?
For now I think I prefer not to import from `vision_transformer.py`. If we get more transformer models that require this `TransformerEncoderBlock`, I think we should create a generic one under `torchvision.ops.misc`, like what we did with `MLP`.
I agree that we should avoid importing from other model architectures, so I like Yosua's proposal; we can move this to `ops.misc` at a later stage. Just wondering whether it should be explicitly marked as private in that case?
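To make that concrete, a locally defined, underscore-prefixed block could look roughly like the sketch below. It is only an organizational illustration: the names, defaults, and exact normalization/activation placement are assumptions, not the code in this PR.

```python
import torch
from torch import nn
from torchvision.ops.misc import MLP


class _TransformerEncoderBlock(nn.Module):
    """Pre-norm encoder block kept private to mobilevit.py (note the leading
    underscore). The real MobileViT block differs in the details discussed
    above (activation, inplace handling, second norm); this is only a sketch."""

    def __init__(self, hidden_dim: int, num_heads: int, mlp_dim: int,
                 dropout: float = 0.0, attention_dropout: float = 0.0) -> None:
        super().__init__()
        self.ln_1 = nn.LayerNorm(hidden_dim)
        self.self_attention = nn.MultiheadAttention(
            hidden_dim, num_heads, dropout=attention_dropout, batch_first=True
        )
        self.dropout = nn.Dropout(dropout)
        self.ln_2 = nn.LayerNorm(hidden_dim)
        # SiLU activation and inplace=None follow the review discussion above;
        # the exact values are assumptions.
        self.mlp = MLP(hidden_dim, [mlp_dim, hidden_dim],
                       activation_layer=nn.SiLU, inplace=None, dropout=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.ln_1(x)
        attn, _ = self.self_attention(y, y, y, need_weights=False)
        x = x + self.dropout(attn)
        return x + self.mlp(self.ln_2(x))
```

If it is later promoted to `torchvision.ops.misc`, dropping the underscore and adding documentation would be the only changes needed at the call sites.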
tensors = []
# Here we loop over the P pixels of the
# tensor x of shape: B, P, N, d
for p in range(x.shape[1]):
I suspect that relying on the shape of the input will make the model not FX-traceable. Can you confirm? If the backbone is not traceable, we should try to implement this in a way that doesn't read the shape of the input at runtime.
That's a good point, I will check how it behaves with `torch.fx` and report back. 👌
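For what it's worth, a quick way to check this could look like the sketch below: the Python loop over `x.shape[1]` is the kind of construct that typically breaks `torch.fx.symbolic_trace`, while an operation applied directly over the full tensor avoids reading the shape at runtime. The two modules are schematic stand-ins, not the actual MobileViT unfolding logic.

```python
import torch
import torch.fx
from torch import nn


class LoopOverPixels(nn.Module):
    """Schematic version of the questioned pattern: loops over the P dimension."""

    def __init__(self, d: int) -> None:
        super().__init__()
        self.proj = nn.Linear(d, d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, P, N, d)
        tensors = []
        for p in range(x.shape[1]):  # reads the input shape at runtime
            tensors.append(self.proj(x[:, p]))
        return torch.stack(tensors, dim=1)


class NoLoop(nn.Module):
    """Schematic alternative: the same projection applied to the whole tensor."""

    def __init__(self, d: int) -> None:
        super().__init__()
        self.proj = nn.Linear(d, d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, P, N, d)
        return self.proj(x)  # nn.Linear acts on the last dim, no per-pixel loop


if __name__ == "__main__":
    for module in (LoopOverPixels(8), NoLoop(8)):
        try:
            torch.fx.symbolic_trace(module)
            print(type(module).__name__, "is FX-traceable")
        except Exception as exc:
            print(type(module).__name__, "failed to trace:", exc)
```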
torchvision/models/mobilevit.py
Outdated
# TODO: Is this separable self-attention?
# TODO: Is this necessary? Probably for the V2 version...
Shall we temporarily remove all V2 `TODO` references to reduce the surface of the PR?
I think it is a good idea indeed. I will make another PR for V2. 👍
torchvision/models/mobilevit.py
Outdated
x = self.conv_last(x)
x = torch.flatten(x, 1)
x = self.fc(x)
Nit: Quite a few models pack these under `classifier`:

vision/torchvision/models/convnext.py, lines 159 to 161 in 4a310f2:

self.classifier = nn.Sequential(
    norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
)
Good point, will change accordingly.
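Applied to the snippet quoted above, the refactor could look roughly like this sketch; the pooling layer, channel variables and layer names are assumptions filled in for illustration, not the actual MobileViT code.

```python
import torch
from torch import nn


class MobileViTHeadSketch(nn.Module):
    """Illustrative refactor: pack the final conv, flatten and fc under a
    single `classifier` submodule, mirroring convnext.py."""

    def __init__(self, in_channels: int, last_channels: int, num_classes: int) -> None:
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Conv2d(in_channels, last_channels, kernel_size=1),  # was self.conv_last
            nn.AdaptiveAvgPool2d(1),  # assumed global pooling before flattening
            nn.Flatten(1),            # was torch.flatten(x, 1)
            nn.Linear(last_channels, num_classes),  # was self.fc
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(x)
```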
torchvision/ops/misc.py
Outdated
@@ -286,7 +286,6 @@ def __init__(
# The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
# https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
params = {} if inplace is None else {"inplace": inplace}
Nit: undo
@yassineAlouini Thanks for the work on it. I've marked the PR as draft. Feel free to unmark it to indicate you are happy for us to have another review. It might be worth resolving the conflicts and merging the main branch.
Noted, thanks @datumbox. 👌
@datumbox I have a new desktop with a GPU and am trying to train the model using it. However, I might need some help/guidance. The first error I get is the following:
Following the stack trace, I get to this file: Then I commented out the line there:
The next error I get is the following:
I tried installing For more context, I have checked that my GPU is visible to PyTorch and I can use it to run dummy models.
I am training on multiple cloud GPUs and will report back on how it goes.
@datumbox I guess I won't receive access to the ImageNet dataset. Should someone take over this work? 🤔
@yassineAlouini It might be worth reaching out to the team, either on the GitHub issue or on Slack, to see whether it's possible for someone else to train it. Unfortunately I no longer have access to the cluster, so I can't help with this one.
V1 of the MobileViT PR.