[DRAFT] Mobile ViT #6965


Draft · wants to merge 15 commits into main
Conversation

@yassineAlouini (Contributor) commented on Nov 20, 2022:

V1 of the MobileViT PR.

@datumbox (Contributor) left a comment:

@yassineAlouini Thanks a lot for the PR! I've added a few nits and questions below. Let me know your thoughts.

For next steps, I propose to:

  1. Verify that the implementation reproduces the paper's accuracies when loaded with the ported weights. This should be done with our existing reference scripts, so we need from you the command you ran and its output showcasing the accuracy.
  2. Make any amendments to the reference scripts that will allow you to reproduce the training setup from the paper, and provide the command you would normally run to train it. We can then check out your PR and run it on our internal infra.

If performing step 1 is also challenging for you due to lack of infra, we can do the same as in step 2 and run it on our side. Just make sure your PR contains all the changes required so that we can run it on our infra by submitting your provided command. In that case, you would need to provide a reference to the weights you use from the original repo and the translation/porting script between the two implementations.
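For reference, a minimal sketch of what such a porting check could look like before running the accuracy evaluation. The builder name mobile_vit_s and the checkpoint path below are hypothetical and not part of this PR:

```python
import torch
from torchvision.models import mobile_vit_s  # hypothetical builder name for this PR

# Load weights ported from the original repo and confirm that every parameter
# maps onto the new implementation before measuring accuracy.
model = mobile_vit_s()
ported_state_dict = torch.load("mobilevit_s_ported.pth", map_location="cpu")  # illustrative path
missing, unexpected = model.load_state_dict(ported_state_dict, strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)
```

The accuracy check itself would then go through references/classification/train.py with --test-only, using whatever model name the PR registers.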

# TODO: What about multi-scale sampler? Check later for V2 training...


class MobileViT_Weights(WeightsEnum):
Contributor:

Minor nit: Could you move the weights definition and their common vars after the model declaration? This is the structure we have right now for all other models. Ideally we should find a way in the future to enforce this order via the linter, or stop enforcing it.
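For illustration, a rough sketch of the requested ordering; the bodies are elided and only the layout matters:

```python
from torch import nn
from torchvision.models._api import WeightsEnum


class MobileViT(nn.Module):
    # model definition comes first
    ...


class MobileViT_Weights(WeightsEnum):
    # weights enum and its common metadata follow the model declaration
    ...
```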

Contributor (author):

Yes, will move of course. 👌

Comment on lines 43 to 46
# TODO: Update with the correct values. For now, these are the expected ones from the paper.
"ImageNet-1K": {
"acc@1": 78.4,
"acc@5": 94.1,
Contributor:

I agree that we should retrain from scratch. As a first step, have you confirmed that the ported weights yield the same accuracy using TorchVision's reference scripts?

self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
self.dropout = nn.Dropout(dropout)

# MLP block (inspired from swin_transformer.py)
Contributor:

@yassineAlouini If I understand correctly, the only changes you need are the activation_layer, the inplace value, and the lack of ln_2. What are your thoughts on reusing the block from vision_transformer.py?

@YosuaMichael, @jdsgomes any thoughts on whether we want to inherit/extend from the above? This building block is tricky because it's not as standard as other common blocks. Moreover, so far we have tried to avoid importing directly from other model architectures. What is your advice here?

Contributor:

For now I would prefer not to import from vision_transformer.py. If more transformer models end up requiring this TransformerEncoderBlock, I think we should create a generic one under torchvision.ops.misc, as we did with MLP.

Contributor:

I agree that we should avoid importing from other model architectures, so I like Yosua's proposal: we can move this to ops.misc at a later stage. Just wondering whether it should be explicitly marked as private in that case?
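To make the idea concrete, a rough sketch of what such a private, generic block under torchvision.ops.misc could look like; the name, signature and defaults here are assumptions rather than this PR's code, and the exact normalization placement would still need to be configurable so ViT and MobileViT variants can differ:

```python
import torch
from torch import nn
from torchvision.ops.misc import MLP


class _TransformerEncoderBlock(nn.Module):
    # Hypothetical shared block: pre-norm self-attention followed by an MLP,
    # parameterized over the bits that differ between models (activation,
    # inplace behaviour, normalization layer).
    def __init__(self, hidden_dim, num_heads, mlp_dim, dropout=0.0, attention_dropout=0.0,
                 norm_layer=nn.LayerNorm, activation_layer=nn.SiLU):
        super().__init__()
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLP(hidden_dim, [mlp_dim, hidden_dim], activation_layer=activation_layer, inplace=None, dropout=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.ln_1(x)
        y, _ = self.self_attention(y, y, y, need_weights=False)
        x = x + self.dropout(y)
        return x + self.mlp(self.ln_2(x))
```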

tensors = []
# Here we loop over the P pixels of the
# tensor x of shape: B, P, N, d
for p in range(x.shape[1]):
Contributor:

I suspect that relying on the shape of the input will make the model not FX-traceable. Can you confirm? If the backbone is not traceable, we should try to implement this in a way that doesn't read the shape of the input at runtime.
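For example, a rough sketch of a shape-agnostic alternative (an illustration of the idea, not this PR's code): fold the P dimension into the batch dimension so there is no Python loop over a runtime shape, which keeps the module symbolically traceable:

```python
import torch
from torch import nn


def apply_transformer_per_pixel(x: torch.Tensor, transformer: nn.Module) -> torch.Tensor:
    # x has shape (B, P, N, d); run the same transformer over every pixel
    # position without iterating over x.shape[1] in Python.
    y = transformer(x.flatten(0, 1))  # (B * P, N, d)
    return y.reshape_as(x)            # back to (B, P, N, d)
```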

Contributor (author):

That's a good point, I will check how it behaves with torch.fx and report back. 👌



# TODO: Is this separable self-attention?
# TODO: Is this necessary? Probably for the V2 version...
Contributor:

Shall we temporarily remove all TODO V2 references to reduce the surface of the PR?

Contributor (author):

I think it is a good idea indeed. I will make another PR for V2. 👍

Comment on lines 352 to 354
x = self.conv_last(x)
x = torch.flatten(x, 1)
x = self.fc(x)
Contributor:

Nit: Quite a few models pack these under classifier:

self.classifier = nn.Sequential(
    norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
)

Contributor (author):

Good point, will change accordingly.

@@ -286,7 +286,6 @@ def __init__(
# The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
# https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
params = {} if inplace is None else {"inplace": inplace}

Contributor:

Nit: undo

@datumbox marked this pull request as draft on November 28, 2022 at 08:43.
@datumbox (Contributor) commented:
@yassineAlouini Thanks for the work on it. I've marked the PR as draft. Feel free to mark it as ready again to indicate you are happy for us to do another review. It might be worth resolving the conflicts and merging the main branch.

@yassineAlouini (Contributor, author) commented:
> @yassineAlouini Thanks for the work on it. I've marked the PR as draft. Feel free to mark it as ready again to indicate you are happy for us to do another review. It might be worth resolving the conflicts and merging the main branch.

Noted, thanks @datumbox. 👌

@yassineAlouini (Contributor, author) commented on Dec 27, 2022:

@datumbox I have a new desktop with a GPU and I am trying to train the model on it. However, I might need some help/guidance.

The first error I get is the following:

AttributeError: module 'torch._C' has no attribute '_cuda_setDevice'

Following the stack trace, I get to the file torch/cuda/__init__.py, line 337, in set_device.

Then, I commented out the offending lines there:

def set_device(device: _device_t) -> None:
    r"""Sets the current device.

    Usage of this function is discouraged in favor of :any:`device`. In most
    cases it's better to use ``CUDA_VISIBLE_DEVICES`` environmental variable.

    Args:
        device (torch.device or int): selected device. This function is a no-op
            if this argument is negative.
    """
    device = _get_device_index(device)
    # if device >= 0:
    #     torch._C._cuda_setDevice(device)

The next error I get is the following:

RuntimeError: Distributed package doesn't have NCCL built in

I tried installing NCCL, but it doesn't change the outcome.
I am in uncharted territory here, so any help is appreciated. Thanks in advance!

For more context, I have checked that my GPU is visible to PyTorch and I can use it to run dummy models.
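One thing I am experimenting with (not sure it is the right approach): since the error suggests my local PyTorch build has no NCCL support, falling back to the gloo backend, or running the reference script as a single process so the distributed setup is never triggered. A rough sketch of the backend fallback:

```python
import os

import torch.distributed as dist

# Rough sketch: single-process process group that falls back to the gloo
# backend when the local PyTorch build has no NCCL support.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
backend = "nccl" if dist.is_nccl_available() else "gloo"
dist.init_process_group(backend=backend, rank=0, world_size=1)
print(f"initialised process group with backend: {backend}")
dist.destroy_process_group()
```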

@yassineAlouini (Contributor, author) commented:

I am training on multiple cloud GPUs and will report back how it goes.

@yassineAlouini (Contributor, author) commented:

@datumbox I guess I won't receive access to the ImageNet dataset. Should someone take over this work? 🤔

@datumbox (Contributor) commented:

@yassineAlouini It might be worth reaching out to the team, either on the GitHub issue or on Slack, to see whether it's possible for someone else to train it. Unfortunately, I no longer have access to the cluster, so I can't help with this one.
