From 2d82685d465b99084cef8c056f268438f477483f Mon Sep 17 00:00:00 2001 From: Yassine Alouini Date: Thu, 18 Aug 2022 16:55:46 +0200 Subject: [PATCH 01/15] [FEAT] Start the mobile vit model. --- docs/source/models/mobilevit.rst | 27 +++++++++++++++++++++++++++ torchvision/models/__init__.py | 1 + torchvision/models/mobilevit.py | 27 +++++++++++++++++++++++++++ 3 files changed, 55 insertions(+) create mode 100644 docs/source/models/mobilevit.rst create mode 100644 torchvision/models/mobilevit.py diff --git a/docs/source/models/mobilevit.rst b/docs/source/models/mobilevit.rst new file mode 100644 index 00000000000..8e64f1b9d1d --- /dev/null +++ b/docs/source/models/mobilevit.rst @@ -0,0 +1,27 @@ +.. + _TODO: Update the documentation with the correct links... + + +MobileViT +=========== + +.. currentmodule:: torchvision.models + +The MobileViT model is based on the `"MobileViT: Light-Weight, General-Purpose, and Mobile-Friendly Vision Transfomer" `__ paper. + + +Model builders +-------------- + +The following model builders can be used to instantiate a MobileViT model, with or +without pre-trained weights. All the model builders internally rely on the +``torchvision.models.mobilevit.MobileViT`` base class. Please refer to the `source +code +`_ for +more details about this class. + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + mobilevit diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index eb949fb3d5c..955d6dfa433 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -13,5 +13,6 @@ from .vgg import * from .vision_transformer import * from .swin_transformer import * +from .mobilevit import * from . import detection, optical_flow, quantization, segmentation, video from ._api import get_model, get_model_weights, get_weight, list_models diff --git a/torchvision/models/mobilevit.py b/torchvision/models/mobilevit.py new file mode 100644 index 00000000000..9fcc6e1f16b --- /dev/null +++ b/torchvision/models/mobilevit.py @@ -0,0 +1,27 @@ +# TODO: Implement v1 and v2 versions of the mobile ViT model. + +from torch import nn +from torchvision.utils import _log_api_usage_once + +__all__ = ["MobileViT"] + +# TODO: Update this... +# Paper links: v1 https://arxiv.org/abs/2110.02178 +# v2 +class MobileViT(nn.Module): + """ + Implements MobileViT from the `"MobileViT: Light-Weight, General-Purpose, and Mobile-Friendly Vision Transfomer" `_ paper. + Args: + TODO: Arguments to be updated... + """ + + def __init__( + self, + ): + super().__init__() + _log_api_usage_once(self) + # TODO: Add blocks... + + # TODO: This is the core thing to implement... + def forward(self, x): + return x From acec4142ca49939240860ba8c07a0ffa1158a290 Mon Sep 17 00:00:00 2001 From: Yassine Alouini Date: Thu, 25 Aug 2022 11:32:33 +0200 Subject: [PATCH 02/15] Start the mobile_vit implementation. WIP. 
--- torchvision/models/mobilevit.py | 116 +++++++++++++++++++++++++++++++- 1 file changed, 113 insertions(+), 3 deletions(-) diff --git a/torchvision/models/mobilevit.py b/torchvision/models/mobilevit.py index 9fcc6e1f16b..9f7406561a1 100644 --- a/torchvision/models/mobilevit.py +++ b/torchvision/models/mobilevit.py @@ -2,12 +2,81 @@ from torch import nn from torchvision.utils import _log_api_usage_once +from torchvision.models._api import register_model, Weights, WeightsEnum +from torchvision.models._utils import _ovewrite_named_param +from torchvision.models._meta import _IMAGENET_CATEGORIES +from torch import Tensor +from typing import Callable, Optional, Any, List +from ..transforms._presets import ImageClassification +from functools import partial -__all__ = ["MobileViT"] +__all__ = ["MobileViT", "MobileViT_Weights", "MobileViT_V2_Weights"] + +# TODO: Is this correct? Maybe not? Need to check the training script... +_COMMON_META = { + "categories": _IMAGENET_CATEGORIES, +} # TODO: Update this... # Paper links: v1 https://arxiv.org/abs/2110.02178 -# v2 +# v2 (what the difference with the V1 paper?) +# TODO: Need a mobile ViT block... +# TODO: Adding weights... Start with V1. +# Things to be done: write the V1, mobileViTblock, weights, documentation... + +class MobileViT_Weights(WeightsEnum): + # TODO: Update these... + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mobilevit.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2", + "_metrics": { + "ImageNet-1K": { + "acc@1": 71.878, + "acc@5": 90.286, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + # TODO: Will be updated later... + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/mobilevit.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", + "_metrics": { + "ImageNet-1K": { + "acc@1": 72.154, + "acc@5": 90.822, + } + }, + "_docs": """ + These weights improve upon the results of the original paper by using a modified version of TorchVision's + `new training recipe + `_. + """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class MobileViT_V2_Weights(WeightsEnum): + pass + + + +class MobileViTBlock(nn.Module): + def forward(self, x: Tensor): + return x + +class MobileViTV2Block(MobileViTBlock): + def forward(self, x: Tensor): + return x + class MobileViT(nn.Module): """ Implements MobileViT from the `"MobileViT: Light-Weight, General-Purpose, and Mobile-Friendly Vision Transfomer" `_ paper. @@ -17,11 +86,52 @@ class MobileViT(nn.Module): def __init__( self, + num_classes: int, + # TODO: Should this be optional? + block: Optional[Callable[..., nn.Module]] = None, ): super().__init__() _log_api_usage_once(self) - # TODO: Add blocks... + # TODO: Add blocks... In progress... + self.num_classes = num_classes + + if block is None: + block = MobileViTBlock # TODO: This is the core thing to implement... def forward(self, x): return x + + +def _mobile_vit( + # TODO: Update the parameters... 
+ weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> MobileViT: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = MobileViT( + # TODO: Update these...Will pass different configurations depending on the size of the mdoel... + **kwargs, + ) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + + return model + + +@register_model() +def mobile_vit_s(): + pass + + +@register_model() +def mobile_vit_s(): + pass + +@register_model() +def mobile_vit_s(): + pass \ No newline at end of file From 1276cd9cc57981aa86deb7d7a86215ab8ce41598 Mon Sep 17 00:00:00 2001 From: Yassine Alouini Date: Wed, 31 Aug 2022 16:50:02 +0200 Subject: [PATCH 03/15] More progress: weights structuring and some comments. --- torchvision/models/mobilevit.py | 128 +++++++++++++++++++++++--------- 1 file changed, 94 insertions(+), 34 deletions(-) diff --git a/torchvision/models/mobilevit.py b/torchvision/models/mobilevit.py index 9f7406561a1..586923c1237 100644 --- a/torchvision/models/mobilevit.py +++ b/torchvision/models/mobilevit.py @@ -12,56 +12,78 @@ __all__ = ["MobileViT", "MobileViT_Weights", "MobileViT_V2_Weights"] -# TODO: Is this correct? Maybe not? Need to check the training script... _COMMON_META = { "categories": _IMAGENET_CATEGORIES, } -# TODO: Update this... -# Paper links: v1 https://arxiv.org/abs/2110.02178 +# For V1, we have 3 sets of weights xx_small (1.3M parameters), x_small (2.3M parameters), and small (5.6M parameters) +# For V2, we have one set of weights. +# Paper link: v1 https://arxiv.org/abs/2110.02178. +# Paper link: v2 https://arxiv.org/pdf/2206.02680.pdf. # v2 (what the difference with the V1 paper?) -# TODO: Need a mobile ViT block... -# TODO: Adding weights... Start with V1. -# Things to be done: write the V1, mobileViTblock, weights, documentation... +# Things to be done: write the V1, MobileViTblock, MobileViTV2block, weights (for V1 and V2), documentation... +# TODO: What about multi-scale sampler? class MobileViT_Weights(WeightsEnum): - # TODO: Update these... IMAGENET1K_V1 = Weights( + # TODO: Update the URL once the model has been trained... url="https://download.pytorch.org/models/mobilevit.pth", - transforms=partial(ImageClassification, crop_size=224), + transforms=partial(ImageClassification, crop_size=256), meta={ **_COMMON_META, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilevit", "_metrics": { + # TODO: Update with the correct values. For now, these are the expected ones from the paper. "ImageNet-1K": { - "acc@1": 71.878, - "acc@5": 90.286, + "acc@1": 78.4, + "acc@5": 94.1, } }, "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", }, ) - # TODO: Will be updated later... - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/mobilevit.pth", - transforms=partial(ImageClassification, crop_size=224, resize_size=232), + DEFAULT = IMAGENET1K_V1 + +class MobileViT_XS_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + # TODO: Update the URL once the model has been trained... 
+ url="https://download.pytorch.org/models/mobilevit_xs.pth", + transforms=partial(ImageClassification, crop_size=256), meta={ **_COMMON_META, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilevit", "_metrics": { + # TODO: Update with the correct values. For now, these are the expected ones from the paper. "ImageNet-1K": { - "acc@1": 72.154, - "acc@5": 90.822, + "acc@1": 74.8, + "acc@5": 92.3, } }, - "_docs": """ - These weights improve upon the results of the original paper by using a modified version of TorchVision's - `new training recipe - `_. - """, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", }, ) - DEFAULT = IMAGENET1K_V2 + DEFAULT = IMAGENET1K_V1 + + +class MobileViT_XXS_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + # TODO: Update the URL once the model has been trained... + url="https://download.pytorch.org/models/mobilevit_xxs.pth", + transforms=partial(ImageClassification, crop_size=256), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilevit", + "_metrics": { + # TODO: Update with the correct values. For now, these are the expected ones from the paper. + "ImageNet-1K": { + "acc@1": 69.0, + "acc@5": 88.9, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 class MobileViT_V2_Weights(WeightsEnum): @@ -73,6 +95,8 @@ class MobileViTBlock(nn.Module): def forward(self, x: Tensor): return x + +# Separable self-attention class MobileViTV2Block(MobileViTBlock): def forward(self, x: Tensor): return x @@ -81,13 +105,17 @@ class MobileViT(nn.Module): """ Implements MobileViT from the `"MobileViT: Light-Weight, General-Purpose, and Mobile-Friendly Vision Transfomer" `_ paper. Args: - TODO: Arguments to be updated... + TODO: Arguments to be updated... in progress + num_classes (int): Number of classes for classification head. Default: 1000. + layers_conf (dict): The layers configuration. """ def __init__( self, - num_classes: int, - # TODO: Should this be optional? + # Trained on ImageNet1K by default. + num_classes: int = 1000, + layers_conf: dict = None, + # TODO: Should this be optional? Yes probably... block: Optional[Callable[..., nn.Module]] = None, ): super().__init__() @@ -97,9 +125,14 @@ def __init__( if block is None: block = MobileViTBlock + # Build the model one layer at a time. + layers: List[nn.Module] = [] + self.features = nn.Sequential(*layers) + # TODO: This is the core thing to implement... def forward(self, x): + x = self.features(x) return x @@ -122,16 +155,43 @@ def _mobile_vit( return model - @register_model() -def mobile_vit_s(): - pass +def mobile_vit_s(*, weights: Optional[MobileViT_Weights] = None, progress: bool = True, **kwargs: Any): + """ + Constructs a mobile_vit_s architecture from + `"MobileViT: Light-Weight, General-Purpose, and Mobile-Friendly Vision Transfomer" `_. + + Args: + weights (:class:`~torchvision.models.MobileViT_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_V2_B_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. 
+ **kwargs: parameters passed to the ``torchvision.models.swin_transformer.MobileVit`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.MobileViT_Weights + :members: + """ + weights = MobileViT_Weights.verify(weights) + return _mobile_vit(weights=weights) @register_model() -def mobile_vit_s(): - pass +def mobile_vit_xs(): + weights = MobileViT_XS_Weights.verify(weights) + return _mobile_vit(weights=weights) + +@register_model() +def mobile_vit_xxs(): + weights = MobileViT_XXS_Weights.verify(weights) + return _mobile_vit(weights=weights) @register_model() -def mobile_vit_s(): - pass \ No newline at end of file +def mobile_vit_v2() + weights = MobileViT_V2_Weights.verify(weights) + return _mobile_vit(weights=weights) \ No newline at end of file From 408bbce903a3e1ee141e4f28f46dd75210466257 Mon Sep 17 00:00:00 2001 From: Yassine Alouini Date: Wed, 2 Nov 2022 20:45:02 +0100 Subject: [PATCH 04/15] Some progress but still not you done. --- torchvision/models/mobilevit.py | 151 ++++++++++++++++++++++++++++++-- 1 file changed, 144 insertions(+), 7 deletions(-) diff --git a/torchvision/models/mobilevit.py b/torchvision/models/mobilevit.py index 586923c1237..f57e478df2b 100644 --- a/torchvision/models/mobilevit.py +++ b/torchvision/models/mobilevit.py @@ -2,13 +2,17 @@ from torch import nn from torchvision.utils import _log_api_usage_once +from torchvision.models.mobilenetv2 import MobileNetV2 from torchvision.models._api import register_model, Weights, WeightsEnum from torchvision.models._utils import _ovewrite_named_param from torchvision.models._meta import _IMAGENET_CATEGORIES from torch import Tensor from typing import Callable, Optional, Any, List from ..transforms._presets import ImageClassification +from ..ops.misc import MLP from functools import partial +from collections import OrderedDict +import torch __all__ = ["MobileViT", "MobileViT_Weights", "MobileViT_V2_Weights"] @@ -22,7 +26,7 @@ # Paper link: v2 https://arxiv.org/pdf/2206.02680.pdf. # v2 (what the difference with the V1 paper?) # Things to be done: write the V1, MobileViTblock, MobileViTV2block, weights (for V1 and V2), documentation... -# TODO: What about multi-scale sampler? +# TODO: What about multi-scale sampler? Check later... class MobileViT_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( @@ -85,18 +89,146 @@ class MobileViT_XXS_Weights(WeightsEnum): ) DEFAULT = IMAGENET1K_V1 - +# TODO: Take inspiration from the V1 weights... In progress... class MobileViT_V2_Weights(WeightsEnum): pass +# The EncoderBlock and Encoder from vision_transformer.py +# TODO: Maybe refactor later... 
+class TransformerEncoderBlock(nn.Module): + """Transformer encoder block.""" + + def __init__( + self, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + ): + super().__init__() + self.num_heads = num_heads + + # Attention block + self.ln_1 = norm_layer(hidden_dim) + self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True) + self.dropout = nn.Dropout(dropout) + + # MLP block (inspired from swin_transformer.py) + self.mlp = MLP(mlp_dim, [hidden_dim, mlp_dim], + activation_layer=nn.GELU, inplace=None, dropout=dropout) + + for m in self.mlp.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.normal_(m.bias, std=1e-6) + + def forward(self, input: torch.Tensor): + torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") + x = self.ln_1(input) + x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False) + x = self.dropout(x) + x = x + input + y = self.mlp(y) + return x + y + + +class TransformerEncoder(nn.Module): + """Transformer Model Encoder for sequence to sequence translation.""" + + def __init__( + self, + seq_length: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + ): + super().__init__() + # Note that batch_size is on the first dim because + # we have batch_first=True in nn.MultiAttention() by default + self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT + self.dropout = nn.Dropout(dropout) + layers: OrderedDict[str, nn.Module] = OrderedDict() + # Multiple + for i in range(num_layers): + layers[f"encoder_layer_{i}"] = TransformerEncoderBlock( + num_heads, + hidden_dim, + mlp_dim, + dropout, + attention_dropout, + norm_layer, + ) + self.layers = nn.Sequential(layers) + self.ln = norm_layer(hidden_dim) + + def forward(self, input: torch.Tensor): + torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") + input = input + self.pos_embedding + return self.ln(self.layers(self.dropout(input))) + +# TODO: We will need a mobilenet block as well. +# TODO: We need to use a Transformer. In progress... Using the one from TorchVision... +# TODO: We need a LayerNorm as well...In progress... class MobileViTBlock(nn.Module): - def forward(self, x: Tensor): - return x + def __init__(self, dim, depth, channel, kernel_size, patch_dimensions, mlp_dim, dropout=0.): + super().__init__() + self.patch_height, self.patch_width = patch_dimensions + self.conv1 = nn.Sequential( + nn.Conv2d(channel, channel, kernel_size, 1, bias=False), + nn.BatchNorm2d(channel), + nn.SiLU()) + # Point-wise convolution (1 x 1) + self.conv2 = nn.Sequential( + nn.Conv2d(channel, dim, 1, 1, 0, bias=False), + nn.BatchNorm2d(dim), + nn.SiLU() + ) + # TODO: Setup the inputs... 
+ self.transformer = TransformerEncoder(dim, depth, 4, 8, mlp_dim, dropout) + + self.conv3 = nn.Sequential( + nn.Conv2d(dim, channel, 1, 1, 0, bias=False), + nn.BatchNorm2d(channel), + nn.SiLU()) + self.conv4 = nn.Sequential( + nn.Conv2d(2 * channel, channel, kernel_size, 1, bias=False), + nn.BatchNorm2d(channel), + nn.SiLU()) + + + def forward(self, x): + y = x.copy() + x = self.conv1(x) + x = self.conv2(x) + # batch, channels, height, width. + _, _, h, w = x.shape + # This is the unfloding (from spatial features to patches) and folding (from patches back to features) parts. + # TODO: What are the values of self.ph and self.pw. + # TODO: Change with a PyTorch operation... In progress... + print(x.shape) + """ + x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw) + x = self.transformer(x) + # The reverse operation... + x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw) + x = self.conv3(x) + x = torch.cat((x, y), 1) + x = self.conv4(x) + """ + return x # Separable self-attention +# TODO: Is this necessary? Check... Maybe class MobileViTV2Block(MobileViTBlock): def forward(self, x: Tensor): return x @@ -147,6 +279,7 @@ def _mobile_vit( model = MobileViT( # TODO: Update these...Will pass different configurations depending on the size of the mdoel... + # In progress... **kwargs, ) @@ -169,7 +302,7 @@ def mobile_vit_s(*, weights: Optional[MobileViT_Weights] = None, progress: bool weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. - **kwargs: parameters passed to the ``torchvision.models.swin_transformer.MobileVit`` + **kwargs: parameters passed to the ``torchvision.models.mobile_vit.MobileVit`` base class. Please refer to the `source code `_ for more details about this class. @@ -192,6 +325,10 @@ def mobile_vit_xxs(): return _mobile_vit(weights=weights) @register_model() -def mobile_vit_v2() +def mobile_vit_v2(): weights = MobileViT_V2_Weights.verify(weights) - return _mobile_vit(weights=weights) \ No newline at end of file + return _mobile_vit(weights=weights) + + +if __name__ == "__main__": + print(MobileViTBlock(1, 3, 1, 1, 0.5)) \ No newline at end of file From 3beaaa30c3a48728de84d1898c5f8d09e66e6404 Mon Sep 17 00:00:00 2001 From: Yassine Alouini Date: Thu, 3 Nov 2022 20:09:53 +0100 Subject: [PATCH 05/15] Fix relative import. --- torchvision/models/mobilevit.py | 96 ++++++++++++++++----------------- 1 file changed, 47 insertions(+), 49 deletions(-) diff --git a/torchvision/models/mobilevit.py b/torchvision/models/mobilevit.py index f57e478df2b..fbc21bc7d37 100644 --- a/torchvision/models/mobilevit.py +++ b/torchvision/models/mobilevit.py @@ -1,18 +1,19 @@ # TODO: Implement v1 and v2 versions of the mobile ViT model. 
-from torch import nn -from torchvision.utils import _log_api_usage_once -from torchvision.models.mobilenetv2 import MobileNetV2 -from torchvision.models._api import register_model, Weights, WeightsEnum -from torchvision.models._utils import _ovewrite_named_param -from torchvision.models._meta import _IMAGENET_CATEGORIES -from torch import Tensor -from typing import Callable, Optional, Any, List -from ..transforms._presets import ImageClassification -from ..ops.misc import MLP -from functools import partial from collections import OrderedDict +from functools import partial +from typing import Any, Callable, List, Optional + import torch +from torch import nn, Tensor +from torchvision.models._api import register_model, Weights, WeightsEnum +from torchvision.models._meta import _IMAGENET_CATEGORIES +from torchvision.models._utils import _ovewrite_named_param +from torchvision.models.mobilenetv2 import MobileNetV2 +from torchvision.utils import _log_api_usage_once + +from torchvision.ops.misc import MLP +from torchvision.transforms._presets import ImageClassification __all__ = ["MobileViT", "MobileViT_Weights", "MobileViT_V2_Weights"] @@ -20,24 +21,25 @@ "categories": _IMAGENET_CATEGORIES, } -# For V1, we have 3 sets of weights xx_small (1.3M parameters), x_small (2.3M parameters), and small (5.6M parameters) -# For V2, we have one set of weights. +# For V1, we have 3 sets of weights xx_small (1.3M parameters), x_small (2.3M parameters), and small (5.6M parameters) +# For V2, we have one set of weights. # Paper link: v1 https://arxiv.org/abs/2110.02178. -# Paper link: v2 https://arxiv.org/pdf/2206.02680.pdf. +# Paper link: v2 https://arxiv.org/pdf/2206.02680.pdf. # v2 (what the difference with the V1 paper?) # Things to be done: write the V1, MobileViTblock, MobileViTV2block, weights (for V1 and V2), documentation... -# TODO: What about multi-scale sampler? Check later... +# TODO: What about multi-scale sampler? Check later... + class MobileViT_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - # TODO: Update the URL once the model has been trained... + # TODO: Update the URL once the model has been trained... url="https://download.pytorch.org/models/mobilevit.pth", transforms=partial(ImageClassification, crop_size=256), meta={ **_COMMON_META, "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilevit", "_metrics": { - # TODO: Update with the correct values. For now, these are the expected ones from the paper. + # TODO: Update with the correct values. For now, these are the expected ones from the paper. "ImageNet-1K": { "acc@1": 78.4, "acc@5": 94.1, @@ -48,16 +50,17 @@ class MobileViT_Weights(WeightsEnum): ) DEFAULT = IMAGENET1K_V1 + class MobileViT_XS_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - # TODO: Update the URL once the model has been trained... + # TODO: Update the URL once the model has been trained... url="https://download.pytorch.org/models/mobilevit_xs.pth", transforms=partial(ImageClassification, crop_size=256), meta={ **_COMMON_META, "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilevit", "_metrics": { - # TODO: Update with the correct values. For now, these are the expected ones from the paper. + # TODO: Update with the correct values. For now, these are the expected ones from the paper. 
"ImageNet-1K": { "acc@1": 74.8, "acc@5": 92.3, @@ -71,14 +74,14 @@ class MobileViT_XS_Weights(WeightsEnum): class MobileViT_XXS_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - # TODO: Update the URL once the model has been trained... + # TODO: Update the URL once the model has been trained... url="https://download.pytorch.org/models/mobilevit_xxs.pth", transforms=partial(ImageClassification, crop_size=256), meta={ **_COMMON_META, "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilevit", "_metrics": { - # TODO: Update with the correct values. For now, these are the expected ones from the paper. + # TODO: Update with the correct values. For now, these are the expected ones from the paper. "ImageNet-1K": { "acc@1": 69.0, "acc@5": 88.9, @@ -89,10 +92,12 @@ class MobileViT_XXS_Weights(WeightsEnum): ) DEFAULT = IMAGENET1K_V1 + # TODO: Take inspiration from the V1 weights... In progress... class MobileViT_V2_Weights(WeightsEnum): pass + # The EncoderBlock and Encoder from vision_transformer.py # TODO: Maybe refactor later... class TransformerEncoderBlock(nn.Module): @@ -116,8 +121,7 @@ def __init__( self.dropout = nn.Dropout(dropout) # MLP block (inspired from swin_transformer.py) - self.mlp = MLP(mlp_dim, [hidden_dim, mlp_dim], - activation_layer=nn.GELU, inplace=None, dropout=dropout) + self.mlp = MLP(mlp_dim, [hidden_dim, mlp_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout) for m in self.mlp.modules(): if isinstance(m, nn.Linear): @@ -155,7 +159,7 @@ def __init__( self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT self.dropout = nn.Dropout(dropout) layers: OrderedDict[str, nn.Module] = OrderedDict() - # Multiple + # Multiple for i in range(num_layers): layers[f"encoder_layer_{i}"] = TransformerEncoderBlock( num_heads, @@ -173,37 +177,28 @@ def forward(self, input: torch.Tensor): input = input + self.pos_embedding return self.ln(self.layers(self.dropout(input))) + # TODO: We will need a mobilenet block as well. # TODO: We need to use a Transformer. In progress... Using the one from TorchVision... # TODO: We need a LayerNorm as well...In progress... class MobileViTBlock(nn.Module): - def __init__(self, dim, depth, channel, kernel_size, patch_dimensions, mlp_dim, dropout=0.): + def __init__(self, dim, depth, channel, kernel_size, patch_dimensions, mlp_dim, dropout=0.0): super().__init__() self.patch_height, self.patch_width = patch_dimensions self.conv1 = nn.Sequential( - nn.Conv2d(channel, channel, kernel_size, 1, bias=False), - nn.BatchNorm2d(channel), - nn.SiLU()) - # Point-wise convolution (1 x 1) - self.conv2 = nn.Sequential( - nn.Conv2d(channel, dim, 1, 1, 0, bias=False), - nn.BatchNorm2d(dim), - nn.SiLU() + nn.Conv2d(channel, channel, kernel_size, 1, bias=False), nn.BatchNorm2d(channel), nn.SiLU() ) + # Point-wise convolution (1 x 1) + self.conv2 = nn.Sequential(nn.Conv2d(channel, dim, 1, 1, 0, bias=False), nn.BatchNorm2d(dim), nn.SiLU()) # TODO: Setup the inputs... 
self.transformer = TransformerEncoder(dim, depth, 4, 8, mlp_dim, dropout) - self.conv3 = nn.Sequential( - nn.Conv2d(dim, channel, 1, 1, 0, bias=False), - nn.BatchNorm2d(channel), - nn.SiLU()) + self.conv3 = nn.Sequential(nn.Conv2d(dim, channel, 1, 1, 0, bias=False), nn.BatchNorm2d(channel), nn.SiLU()) self.conv4 = nn.Sequential( - nn.Conv2d(2 * channel, channel, kernel_size, 1, bias=False), - nn.BatchNorm2d(channel), - nn.SiLU()) - + nn.Conv2d(2 * channel, channel, kernel_size, 1, bias=False), nn.BatchNorm2d(channel), nn.SiLU() + ) def forward(self, x): y = x.copy() @@ -224,7 +219,7 @@ def forward(self, x): x = torch.cat((x, y), 1) x = self.conv4(x) """ - return x + return x # Separable self-attention @@ -233,6 +228,7 @@ class MobileViTV2Block(MobileViTBlock): def forward(self, x: Tensor): return x + class MobileViT(nn.Module): """ Implements MobileViT from the `"MobileViT: Light-Weight, General-Purpose, and Mobile-Friendly Vision Transfomer" `_ paper. @@ -244,10 +240,10 @@ class MobileViT(nn.Module): def __init__( self, - # Trained on ImageNet1K by default. + # Trained on ImageNet1K by default. num_classes: int = 1000, layers_conf: dict = None, - # TODO: Should this be optional? Yes probably... + # TODO: Should this be optional? Yes probably... block: Optional[Callable[..., nn.Module]] = None, ): super().__init__() @@ -257,11 +253,10 @@ def __init__( if block is None: block = MobileViTBlock - # Build the model one layer at a time. + # Build the model one layer at a time. layers: List[nn.Module] = [] self.features = nn.Sequential(*layers) - # TODO: This is the core thing to implement... def forward(self, x): x = self.features(x) @@ -269,7 +264,7 @@ def forward(self, x): def _mobile_vit( - # TODO: Update the parameters... + # TODO: Update the parameters... weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, @@ -278,7 +273,7 @@ def _mobile_vit( _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = MobileViT( - # TODO: Update these...Will pass different configurations depending on the size of the mdoel... + # TODO: Update these...Will pass different configurations depending on the size of the mdoel... # In progress... **kwargs, ) @@ -288,6 +283,7 @@ def _mobile_vit( return model + @register_model() def mobile_vit_s(*, weights: Optional[MobileViT_Weights] = None, progress: bool = True, **kwargs: Any): """ @@ -319,11 +315,13 @@ def mobile_vit_xs(): weights = MobileViT_XS_Weights.verify(weights) return _mobile_vit(weights=weights) + @register_model() def mobile_vit_xxs(): weights = MobileViT_XXS_Weights.verify(weights) return _mobile_vit(weights=weights) + @register_model() def mobile_vit_v2(): weights = MobileViT_V2_Weights.verify(weights) @@ -331,4 +329,4 @@ def mobile_vit_v2(): if __name__ == "__main__": - print(MobileViTBlock(1, 3, 1, 1, 0.5)) \ No newline at end of file + print(MobileViTBlock(1, 3, 1, 1, 0.5)) From 55bb81c31e50945e6b02912b482a486bb2f4cd12 Mon Sep 17 00:00:00 2001 From: Yassine Alouini Date: Thu, 3 Nov 2022 20:48:39 +0100 Subject: [PATCH 06/15] MobileViTBlock runs. 
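The diff below swaps the NumPy-style `x.copy()` for `x.detach().clone()` when copying the block input for the skip connection. A small standalone sketch (not part of the patch) of what that choice implies for autograd: `clone()` keeps the copy attached to the graph, while `detach().clone()` cuts it off, so no gradient flows back through the concatenated skip branch.

```python
import torch

x = torch.randn(2, 3, requires_grad=True)

y_attached = x.clone()           # differentiable copy: stays in the autograd graph
y_detached = x.detach().clone()  # plain copy: no gradient history

(y_attached.sum() + y_detached.sum()).backward()
print(x.grad)                    # all ones: only the clone() branch contributed gradients
```

Whether blocking gradients through the skip path is intended here is worth double-checking before training.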
--- torchvision/models/mobilevit.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/torchvision/models/mobilevit.py b/torchvision/models/mobilevit.py index fbc21bc7d37..2e330ab9a1f 100644 --- a/torchvision/models/mobilevit.py +++ b/torchvision/models/mobilevit.py @@ -184,7 +184,8 @@ def forward(self, input: torch.Tensor): class MobileViTBlock(nn.Module): - def __init__(self, dim, depth, channel, kernel_size, patch_dimensions, mlp_dim, dropout=0.0): + def __init__(self, dim, depth, channel, kernel_size, patch_dimensions, mlp_dim, dropout=0.0, + attention_dropout=0.5): super().__init__() self.patch_height, self.patch_width = patch_dimensions self.conv1 = nn.Sequential( @@ -193,7 +194,7 @@ def __init__(self, dim, depth, channel, kernel_size, patch_dimensions, mlp_dim, # Point-wise convolution (1 x 1) self.conv2 = nn.Sequential(nn.Conv2d(channel, dim, 1, 1, 0, bias=False), nn.BatchNorm2d(dim), nn.SiLU()) # TODO: Setup the inputs... - self.transformer = TransformerEncoder(dim, depth, 4, 8, mlp_dim, dropout) + self.transformer = TransformerEncoder(dim, depth, 4, 8, mlp_dim, dropout, attention_dropout) self.conv3 = nn.Sequential(nn.Conv2d(dim, channel, 1, 1, 0, bias=False), nn.BatchNorm2d(channel), nn.SiLU()) self.conv4 = nn.Sequential( @@ -201,7 +202,7 @@ def __init__(self, dim, depth, channel, kernel_size, patch_dimensions, mlp_dim, ) def forward(self, x): - y = x.copy() + y = x.detach().clone() x = self.conv1(x) x = self.conv2(x) # batch, channels, height, width. @@ -329,4 +330,9 @@ def mobile_vit_v2(): if __name__ == "__main__": - print(MobileViTBlock(1, 3, 1, 1, 0.5)) + # dim, depth, channel, kernel_size, patch_dimensions, mlp_dim, dropout=0.0 + block = MobileViTBlock(dim=1, depth=3, channel=3, kernel_size=3, patch_dimensions=(2, 2), + mlp_dim=2, dropout=0.5) + print(block) + x = torch.rand(3, 3, 3, 3) + print(block(x)) \ No newline at end of file From 6acde624ad32a988353307cc613de2776a3d105c Mon Sep 17 00:00:00 2001 From: Yassine Alouini Date: Sun, 20 Nov 2022 17:52:29 +0100 Subject: [PATCH 07/15] MobileVit (the v1 version) runs. Next: train. 
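The patch below introduces `_unfold` and `_fold`, which move between the `B x d x H x W` image layout and the `B x P x N x d` patch layout consumed by the transformer (`P` pixels per patch, `N` patches). A minimal standalone sketch of that round trip, assuming the patch size divides `H` and `W` evenly; it illustrates the idea and is not the exact implementation in the diff:

```python
import torch

B, d, H, W = 2, 8, 4, 6          # batch, embedding dim, image height/width
ph, pw = 2, 2                    # patch height/width
nh, nw = H // ph, W // pw        # number of patches along each axis
P, N = ph * pw, nh * nw          # pixels per patch, number of patches

x = torch.randn(B, d, H, W)

# "Unfold": B x d x H x W  ->  B x P x N x d
patches = (
    x.reshape(B, d, nh, ph, nw, pw)
    .permute(0, 3, 5, 2, 4, 1)    # B x ph x pw x nh x nw x d
    .reshape(B, P, N, d)
)

# "Fold": B x P x N x d  ->  B x d x H x W (exact inverse)
restored = (
    patches.reshape(B, ph, pw, nh, nw, d)
    .permute(0, 5, 3, 1, 4, 2)    # B x d x nh x ph x nw x pw
    .reshape(B, d, H, W)
)

assert torch.equal(restored, x)
```

In this layout the transformer attends across the `N` patch positions for each of the `P` pixel offsets (the per-pixel loop in `TransformerEncoder.forward`), while the surrounding convolutions handle the local mixing.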
--- torchvision/models/mobilevit.py | 279 ++++++++++++++++------- torchvision/models/vision_transformer.py | 2 + torchvision/ops/misc.py | 2 +- 3 files changed, 199 insertions(+), 84 deletions(-) diff --git a/torchvision/models/mobilevit.py b/torchvision/models/mobilevit.py index 2e330ab9a1f..75c686ae3a5 100644 --- a/torchvision/models/mobilevit.py +++ b/torchvision/models/mobilevit.py @@ -2,32 +2,33 @@ from collections import OrderedDict from functools import partial -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Tuple import torch from torch import nn, Tensor from torchvision.models._api import register_model, Weights, WeightsEnum from torchvision.models._meta import _IMAGENET_CATEGORIES from torchvision.models._utils import _ovewrite_named_param -from torchvision.models.mobilenetv2 import MobileNetV2 -from torchvision.utils import _log_api_usage_once +from torchvision.models.mobilenetv2 import InvertedResidual from torchvision.ops.misc import MLP from torchvision.transforms._presets import ImageClassification +from torchvision.utils import _log_api_usage_once -__all__ = ["MobileViT", "MobileViT_Weights", "MobileViT_V2_Weights"] +__all__ = ["MobileViT", "MobileViT_Weights", "MobileViT_V2_Weights", "MobileViTV2"] _COMMON_META = { "categories": _IMAGENET_CATEGORIES, } + # For V1, we have 3 sets of weights xx_small (1.3M parameters), x_small (2.3M parameters), and small (5.6M parameters) # For V2, we have one set of weights. # Paper link: v1 https://arxiv.org/abs/2110.02178. # Paper link: v2 https://arxiv.org/pdf/2206.02680.pdf. # v2 (what the difference with the V1 paper?) # Things to be done: write the V1, MobileViTblock, MobileViTV2block, weights (for V1 and V2), documentation... -# TODO: What about multi-scale sampler? Check later... +# TODO: What about multi-scale sampler? Check later for V2 training... class MobileViT_Weights(WeightsEnum): @@ -106,7 +107,7 @@ class TransformerEncoderBlock(nn.Module): def __init__( self, num_heads: int, - hidden_dim: int, + hidden_dim: int, # This is the embedding dim (known as E or d), should be a multiple of num_heads... mlp_dim: int, dropout: float, attention_dropout: float, @@ -114,14 +115,13 @@ def __init__( ): super().__init__() self.num_heads = num_heads - - # Attention block self.ln_1 = norm_layer(hidden_dim) self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True) self.dropout = nn.Dropout(dropout) # MLP block (inspired from swin_transformer.py) - self.mlp = MLP(mlp_dim, [hidden_dim, mlp_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout) + # TODO: Rename the hidden_dim variable... + self.mlp = MLP(hidden_dim, [mlp_dim, hidden_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout) for m in self.mlp.modules(): if isinstance(m, nn.Linear): @@ -130,12 +130,13 @@ def __init__( nn.init.normal_(m.bias, std=1e-6) def forward(self, input: torch.Tensor): + # B x N x D torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") x = self.ln_1(input) x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False) x = self.dropout(x) x = x + input - y = self.mlp(y) + y = self.mlp(x) return x + y @@ -144,22 +145,18 @@ class TransformerEncoder(nn.Module): def __init__( self, - seq_length: int, - num_layers: int, - num_heads: int, - hidden_dim: int, + num_layers: int, # This is the depth... Okay... + num_heads: int, # This is number of heads in the multi-attention layer... Okay ... 
+ hidden_dim: int, # This is the embedding or d dimension, should be a multiple of num_heads... mlp_dim: int, dropout: float, attention_dropout: float, norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), ): super().__init__() - # Note that batch_size is on the first dim because - # we have batch_first=True in nn.MultiAttention() by default - self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT self.dropout = nn.Dropout(dropout) layers: OrderedDict[str, nn.Module] = OrderedDict() - # Multiple + # Multiple iteration over the num_layers/depth... for i in range(num_layers): layers[f"encoder_layer_{i}"] = TransformerEncoderBlock( num_heads, @@ -170,61 +167,114 @@ def __init__( norm_layer, ) self.layers = nn.Sequential(layers) - self.ln = norm_layer(hidden_dim) - - def forward(self, input: torch.Tensor): - torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") - input = input + self.pos_embedding - return self.ln(self.layers(self.dropout(input))) - -# TODO: We will need a mobilenet block as well. -# TODO: We need to use a Transformer. In progress... Using the one from TorchVision... -# TODO: We need a LayerNorm as well...In progress... + def forward(self, x: torch.Tensor): + tensors = [] + # Here we loop over the P pixels of the + # tensor x of shape: B, P, N, d + for p in range(x.shape[1]): + tmp_tensor = self.layers(x[:, p, :, :]) + # Adding back the patch dimension before concatenating + tmp_tensor = tmp_tensor.unsqueeze(1) + tensors.append(tmp_tensor) + return torch.cat(tensors, dim=1) class MobileViTBlock(nn.Module): - def __init__(self, dim, depth, channel, kernel_size, patch_dimensions, mlp_dim, dropout=0.0, - attention_dropout=0.5): + def __init__( + self, + dim, + depth, + channel, + kernel_size, + patch_size: Tuple[int, int], + mlp_dim: int, + dropout: float = 0.0, + attention_dropout: float = 0.5, + ): super().__init__() - self.patch_height, self.patch_width = patch_dimensions + self.patch_size = patch_size self.conv1 = nn.Sequential( - nn.Conv2d(channel, channel, kernel_size, 1, bias=False), nn.BatchNorm2d(channel), nn.SiLU() + nn.Conv2d(channel, channel, kernel_size, 1, 1, bias=False), nn.BatchNorm2d(channel), nn.SiLU() ) - # Point-wise convolution (1 x 1) self.conv2 = nn.Sequential(nn.Conv2d(channel, dim, 1, 1, 0, bias=False), nn.BatchNorm2d(dim), nn.SiLU()) - # TODO: Setup the inputs... - self.transformer = TransformerEncoder(dim, depth, 4, 8, mlp_dim, dropout, attention_dropout) - + num_heads = 4 + self.transformer = TransformerEncoder(depth, num_heads, dim, mlp_dim, dropout, attention_dropout) self.conv3 = nn.Sequential(nn.Conv2d(dim, channel, 1, 1, 0, bias=False), nn.BatchNorm2d(channel), nn.SiLU()) self.conv4 = nn.Sequential( - nn.Conv2d(2 * channel, channel, kernel_size, 1, bias=False), nn.BatchNorm2d(channel), nn.SiLU() + nn.Conv2d(2 * channel, channel, kernel_size, 1, 1, bias=False), nn.BatchNorm2d(channel), nn.SiLU() ) + @staticmethod + def _unfold(x: Tensor, patch_size: Tuple[int, int], n_patches: Tuple[int, int]) -> Tensor: + """ + Unfold a batch of B image tensors B x d x H X W into a batch of B P x N x d tensors + (N is the number of patches) + These P x N x d tensors are then used by the transformer encoder where d is the hidden + dimension/encoding, N is the sequence length and we loop over the pixels P. 
+ """ + h_patch, w_patch = patch_size + n_h_patch, n_w_patch = n_patches + # P is the number of pixels + P = h_patch * w_patch + B, d, _, _ = x.shape + N = n_w_patch * n_h_patch + + # We reshape from B x d x H x W to (B * d * n_h_patch) x h_patch x n_w_patch x w_patch + x = x.reshape(B * d * n_w_patch, h_patch, n_h_patch, w_patch) + # Then we transpose (B * d * n_h_patch) x h_patch x n_w_patch x w_patch into (B * d * n_h_patch) x n_w_patch x h_patch x w_patch + x = x.transpose(1, 2) + # Next, we reshape (B * d * n_h_patch) x n_w_patch x h_patch x w_patch into B x d x N x P + x = x.reshape(B, d, N, P) + # And we finish by transposing B x d x N x P into B x P x N x d + x = x.transpose(1, 3) + return x + + @staticmethod + def _fold(x: Tensor, patch_size: Tuple[int, int], n_patches: Tuple[int, int]) -> Tensor: + """ + Fold a batch of B P x N x d tensors + (N is the number of patches) into a batch of B d x H x W image tensors. + This is the reverse operation of unfold. + """ + h_patch, w_patch = patch_size + n_h_patch, n_w_patch = n_patches + B, _, _, d = x.shape + x = x.transpose(1, 3) + + x = x.reshape(B * d * n_h_patch, n_w_patch, h_patch, w_patch) + x = x.transpose(1, 2) + x = x.reshape(B, d, n_h_patch * h_patch, n_w_patch * w_patch) + return x + def forward(self, x): + # We compute how many patches along the width patch dimension, the height patch dimension, + # and the total number of patches. + # The number of patches N x the numbre of pixels P in a patch + # is equal to the image area H x W. + _, _, H, W = x.shape + h_patch, w_patch = self.patch_size + n_w_patch = W // w_patch + n_h_patch = H // h_patch + n_patches = (n_h_patch, n_w_patch) y = x.detach().clone() x = self.conv1(x) x = self.conv2(x) - # batch, channels, height, width. - _, _, h, w = x.shape - # This is the unfloding (from spatial features to patches) and folding (from patches back to features) parts. - # TODO: What are the values of self.ph and self.pw. - # TODO: Change with a PyTorch operation... In progress... - print(x.shape) - """ - x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw) + x = self._unfold(x, patch_size=self.patch_size, n_patches=n_patches) + # We get a tensor of shape: B x P x N x d after the previous steps x = self.transformer(x) - # The reverse operation... - x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw) + # The transformer blocks keep the B x P x N x d shape + x = self._fold(x, patch_size=self.patch_size, n_patches=n_patches) + # We get back B x d x H x W tensors x = self.conv3(x) + # Then we get the inital shape B x C x H X W x = torch.cat((x, y), 1) x = self.conv4(x) - """ return x -# Separable self-attention -# TODO: Is this necessary? Check... Maybe +# TODO: Is this separable self-attention? +# TODO: Is this necessary? Probably for the V2 version... class MobileViTV2Block(MobileViTBlock): def forward(self, x: Tensor): return x @@ -234,48 +284,94 @@ class MobileViT(nn.Module): """ Implements MobileViT from the `"MobileViT: Light-Weight, General-Purpose, and Mobile-Friendly Vision Transfomer" `_ paper. Args: - TODO: Arguments to be updated... in progress num_classes (int): Number of classes for classification head. Default: 1000. - layers_conf (dict): The layers configuration. + d (List[int]): A list of the layers' dimensions. + c (List[int]): A list of the layers' channels. + expand_ratio (int): The expansion ratio of the InvertedResidual block. Default: 4. 
""" - def __init__( - self, - # Trained on ImageNet1K by default. - num_classes: int = 1000, - layers_conf: dict = None, - # TODO: Should this be optional? Yes probably... - block: Optional[Callable[..., nn.Module]] = None, - ): + def __init__(self, num_classes: int = 1000, d: List[int] = None, c: List[int] = None, expand_ratio: int = 4): super().__init__() _log_api_usage_once(self) - # TODO: Add blocks... In progress... + if len(d) != 3: + raise ValueError(f"d should be non-empty list, got {d}") + if len(c) != 11: + raise ValueError(f"c should be non-empty list, got {c}") self.num_classes = num_classes - - if block is None: - block = MobileViTBlock - # Build the model one layer at a time. - layers: List[nn.Module] = [] + self.expand_ratio = expand_ratio + # n x n convolution as an input layer + # 3 is the number of RGB channels thus it is the + # input dimension. + self.conv_first = nn.Sequential(nn.Conv2d(3, c[0], 3, 2, 1, bias=False), nn.BatchNorm2d(c[0]), nn.SiLU()) + self.transformer_depths = [2, 4, 3] + layers = [ + InvertedResidual(inp=c[0], oup=c[1], stride=1, expand_ratio=self.expand_ratio), + InvertedResidual(inp=c[1], oup=c[2], stride=2, expand_ratio=self.expand_ratio), + # Twice the same block used here. + InvertedResidual(inp=c[2], oup=c[3], stride=1, expand_ratio=self.expand_ratio), + InvertedResidual(inp=c[2], oup=c[3], stride=1, expand_ratio=self.expand_ratio), + InvertedResidual(inp=c[3], oup=c[4], stride=2, expand_ratio=self.expand_ratio), + MobileViTBlock( + dim=d[0], + channel=c[5], + depth=self.transformer_depths[0], + kernel_size=3, + patch_size=(2, 2), + mlp_dim=d[0] * 2, + ), + InvertedResidual(inp=c[5], oup=c[6], stride=2, expand_ratio=self.expand_ratio), + MobileViTBlock( + dim=d[1], + channel=c[7], + depth=self.transformer_depths[1], + kernel_size=3, + patch_size=(2, 2), + mlp_dim=d[1] * 4, + ), + InvertedResidual(inp=c[7], oup=c[8], stride=2, expand_ratio=self.expand_ratio), + MobileViTBlock( + dim=d[2], + channel=c[9], + depth=self.transformer_depths[2], + kernel_size=3, + patch_size=(2, 2), + mlp_dim=d[2] * 4, + ), + ] self.features = nn.Sequential(*layers) + # height // 32 gives 8 for height 256... + self.avgpool = nn.AvgPool2d(8, 1) + # 1 x 1 convolution as an output layer (before fc) + self.conv_last = nn.Sequential(nn.Conv2d(c[9], c[10], 1, 1, 0, bias=False), nn.BatchNorm2d(c[10]), nn.SiLU()) + self.fc = nn.Linear(c[10], self.num_classes) - # TODO: This is the core thing to implement... def forward(self, x): + x = self.conv_first(x) x = self.features(x) + x = self.avgpool(x) + x = self.conv_last(x) + x = torch.flatten(x, 1) + x = self.fc(x) return x def _mobile_vit( - # TODO: Update the parameters... + num_classes: int, + d: List[int], + c: List[int], weights: Optional[WeightsEnum], progress: bool, + expand_ratio: int = 4, **kwargs: Any, ) -> MobileViT: if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = MobileViT( - # TODO: Update these...Will pass different configurations depending on the size of the mdoel... - # In progress... + num_classes=num_classes, + c=c, + d=d, + expand_ratio=expand_ratio, **kwargs, ) @@ -294,7 +390,7 @@ def mobile_vit_s(*, weights: Optional[MobileViT_Weights] = None, progress: bool Args: weights (:class:`~torchvision.models.MobileViT_Weights`, optional): The pretrained weights to use. See - :class:`~torchvision.models.Swin_V2_B_Weights` below for + :class:`~torchvision.models.MobileViT_Weights` below for more details, and possible values. 
By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the @@ -307,32 +403,49 @@ def mobile_vit_s(*, weights: Optional[MobileViT_Weights] = None, progress: bool .. autoclass:: torchvision.models.MobileViT_Weights :members: """ + s_c = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640] + s_d = [144, 192, 240] weights = MobileViT_Weights.verify(weights) - return _mobile_vit(weights=weights) + return _mobile_vit(c=s_c, d=s_d, weights=weights, progress=progress, **kwargs) @register_model() -def mobile_vit_xs(): +def mobile_vit_xs(*, weights: Optional[MobileViT_Weights] = None, progress: bool = True, **kwargs: Any): + # TODO: Add the documentation + xs_c = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384] + xs_d = [96, 120, 144] weights = MobileViT_XS_Weights.verify(weights) - return _mobile_vit(weights=weights) + return _mobile_vit(c=xs_c, d=xs_d, weights=weights, progress=progress, **kwargs) @register_model() -def mobile_vit_xxs(): +def mobile_vit_xxs(*, weights: Optional[MobileViT_Weights] = None, progress: bool = True, **kwargs: Any): + # TODO: Add the documentation + xxs_c = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320] + xxs_d = [64, 80, 96] weights = MobileViT_XXS_Weights.verify(weights) - return _mobile_vit(weights=weights) + return _mobile_vit(c=xxs_c, d=xxs_d, weights=weights, progress=progress, expand_ratio=2, **kwargs) + + +# TODO: Implement this... +def _mobile_vit_v2(): + pass @register_model() def mobile_vit_v2(): + # TODO: Finish and add documentation. weights = MobileViT_V2_Weights.verify(weights) - return _mobile_vit(weights=weights) + return _mobile_vit_v2(weights=weights) if __name__ == "__main__": - # dim, depth, channel, kernel_size, patch_dimensions, mlp_dim, dropout=0.0 - block = MobileViTBlock(dim=1, depth=3, channel=3, kernel_size=3, patch_dimensions=(2, 2), - mlp_dim=2, dropout=0.5) - print(block) - x = torch.rand(3, 3, 3, 3) - print(block(x)) \ No newline at end of file + block = MobileViTBlock(dim=8 * 10, depth=1, channel=3, kernel_size=3, patch_size=(2, 2), mlp_dim=2, dropout=0.5) + # B x C x H x W + x = torch.rand(10, 3, 10, 10) + assert block(x).shape == (10, 3, 10, 10) + + # Batch of 10 RGB (256 x 256) random images + img = torch.randn(10, 3, 256, 256) + model = mobile_vit_s(num_classes=1000) + assert model(img).shape == (10, 1000) diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index a0a42ab07b7..21aa1f7a23c 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -108,6 +108,8 @@ def __init__( self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout) def forward(self, input: torch.Tensor): + # By default, these values are (*, 32, 512) for the original Transformer paper... + # So need some reshaping to get to this shape. 
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") x = self.ln_1(input) x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index d4bda7decc5..7a9c1590515 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -286,7 +286,7 @@ def __init__( # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal: # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py params = {} if inplace is None else {"inplace": inplace} - + print(in_channels, hidden_channels) layers = [] in_dim = in_channels for hidden_dim in hidden_channels[:-1]: From c373255a95123e1f6c20a88bab67f6737f0fdcef Mon Sep 17 00:00:00 2001 From: Yassine Alouini Date: Sun, 20 Nov 2022 18:12:00 +0100 Subject: [PATCH 08/15] [CI/CD] Fix URL of flake8 in pre-commits. --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 463a97359ab..e8dce60467d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,7 +18,7 @@ repos: - black == 22.3.0 - usort == 1.0.2 - - repo: https://gitlab.com/pycqa/flake8 + - repo: https://github.com/pycqa/flake8 rev: 3.9.2 hooks: - id: flake8 From a6dd59903ed7c5e1649eb8421d645980084592ad Mon Sep 17 00:00:00 2001 From: Yassine Alouini Date: Sun, 20 Nov 2022 18:13:35 +0100 Subject: [PATCH 09/15] Remove changes made by mistake. --- torchvision/models/vision_transformer.py | 2 -- torchvision/ops/misc.py | 1 - 2 files changed, 3 deletions(-) diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 21aa1f7a23c..a0a42ab07b7 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -108,8 +108,6 @@ def __init__( self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout) def forward(self, input: torch.Tensor): - # By default, these values are (*, 32, 512) for the original Transformer paper... - # So need some reshaping to get to this shape. torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") x = self.ln_1(input) x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 7a9c1590515..a62c00a4765 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -286,7 +286,6 @@ def __init__( # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal: # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py params = {} if inplace is None else {"inplace": inplace} - print(in_channels, hidden_channels) layers = [] in_dim = in_channels for hidden_dim in hidden_channels[:-1]: From cff346ec46a20e56dae71a43918088ca0c484a3d Mon Sep 17 00:00:00 2001 From: Yassine Alouini Date: Sun, 27 Nov 2022 14:53:13 +0100 Subject: [PATCH 10/15] Few changes thanks to code review. 
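Patch 10/15 below reshuffles the weight enums and the classifier head; with it applied on top of a torchvision checkout, the registered builders can be smoke-tested through the standard entry points. A rough sketch, assuming `weights=None` (no pretrained checkpoint has been published yet) and an explicit `num_classes`, since `_mobile_vit` gives it no default:

```python
import torch
from torchvision.models import get_model, list_models

# The @register_model() decorators make the new builders discoverable.
print([name for name in list_models() if "mobile_vit" in name])

model = get_model("mobile_vit_s", weights=None, num_classes=1000).eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 256, 256))  # 256x256 input, matching crop_size=256
assert logits.shape == (1, 1000)
```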
--- torchvision/models/mobilevit.py | 166 ++++++++++------------- torchvision/models/vision_transformer.py | 1 - 2 files changed, 69 insertions(+), 98 deletions(-) diff --git a/torchvision/models/mobilevit.py b/torchvision/models/mobilevit.py index 75c686ae3a5..dabfc3407d3 100644 --- a/torchvision/models/mobilevit.py +++ b/torchvision/models/mobilevit.py @@ -15,88 +15,13 @@ from torchvision.transforms._presets import ImageClassification from torchvision.utils import _log_api_usage_once -__all__ = ["MobileViT", "MobileViT_Weights", "MobileViT_V2_Weights", "MobileViTV2"] +__all__ = ["MobileViT", "MobileViT_Weights"] _COMMON_META = { "categories": _IMAGENET_CATEGORIES, } -# For V1, we have 3 sets of weights xx_small (1.3M parameters), x_small (2.3M parameters), and small (5.6M parameters) -# For V2, we have one set of weights. -# Paper link: v1 https://arxiv.org/abs/2110.02178. -# Paper link: v2 https://arxiv.org/pdf/2206.02680.pdf. -# v2 (what the difference with the V1 paper?) -# Things to be done: write the V1, MobileViTblock, MobileViTV2block, weights (for V1 and V2), documentation... -# TODO: What about multi-scale sampler? Check later for V2 training... - - -class MobileViT_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - # TODO: Update the URL once the model has been trained... - url="https://download.pytorch.org/models/mobilevit.pth", - transforms=partial(ImageClassification, crop_size=256), - meta={ - **_COMMON_META, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilevit", - "_metrics": { - # TODO: Update with the correct values. For now, these are the expected ones from the paper. - "ImageNet-1K": { - "acc@1": 78.4, - "acc@5": 94.1, - } - }, - "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class MobileViT_XS_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - # TODO: Update the URL once the model has been trained... - url="https://download.pytorch.org/models/mobilevit_xs.pth", - transforms=partial(ImageClassification, crop_size=256), - meta={ - **_COMMON_META, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilevit", - "_metrics": { - # TODO: Update with the correct values. For now, these are the expected ones from the paper. - "ImageNet-1K": { - "acc@1": 74.8, - "acc@5": 92.3, - } - }, - "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class MobileViT_XXS_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - # TODO: Update the URL once the model has been trained... - url="https://download.pytorch.org/models/mobilevit_xxs.pth", - transforms=partial(ImageClassification, crop_size=256), - meta={ - **_COMMON_META, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilevit", - "_metrics": { - # TODO: Update with the correct values. For now, these are the expected ones from the paper. - "ImageNet-1K": { - "acc@1": 69.0, - "acc@5": 88.9, - } - }, - "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", - }, - ) - DEFAULT = IMAGENET1K_V1 - - -# TODO: Take inspiration from the V1 weights... In progress... 
-class MobileViT_V2_Weights(WeightsEnum): - pass # The EncoderBlock and Encoder from vision_transformer.py @@ -193,6 +118,7 @@ def __init__( attention_dropout: float = 0.5, ): super().__init__() + _log_api_usage_once(self) self.patch_size = patch_size self.conv1 = nn.Sequential( nn.Conv2d(channel, channel, kernel_size, 1, 1, bias=False), nn.BatchNorm2d(channel), nn.SiLU() @@ -273,13 +199,6 @@ def forward(self, x): return x -# TODO: Is this separable self-attention? -# TODO: Is this necessary? Probably for the V2 version... -class MobileViTV2Block(MobileViTBlock): - def forward(self, x: Tensor): - return x - - class MobileViT(nn.Module): """ Implements MobileViT from the `"MobileViT: Light-Weight, General-Purpose, and Mobile-Friendly Vision Transfomer" `_ paper. @@ -343,15 +262,16 @@ def __init__(self, num_classes: int = 1000, d: List[int] = None, c: List[int] = self.avgpool = nn.AvgPool2d(8, 1) # 1 x 1 convolution as an output layer (before fc) self.conv_last = nn.Sequential(nn.Conv2d(c[9], c[10], 1, 1, 0, bias=False), nn.BatchNorm2d(c[10]), nn.SiLU()) - self.fc = nn.Linear(c[10], self.num_classes) + self.classifier = nn.Sequential( + nn.Flatten(1), nn.Linear(c[10], self.num_classes) + ) def forward(self, x): x = self.conv_first(x) x = self.features(x) x = self.avgpool(x) x = self.conv_last(x) - x = torch.flatten(x, 1) - x = self.fc(x) + x = self.classifier(x) return x @@ -381,6 +301,69 @@ def _mobile_vit( return model +class MobileViT_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + # TODO: Update the URL once the model has been trained... + url="https://download.pytorch.org/models/mobilevit.pth", + transforms=partial(ImageClassification, crop_size=256), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilevit", + "_metrics": { + # TODO: Update with the correct values. For now, these are the expected ones from the paper. + "ImageNet-1K": { + "acc@1": 78.4, + "acc@5": 94.1, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class MobileViT_XS_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + # TODO: Update the URL once the model has been trained... + url="https://download.pytorch.org/models/mobilevit_xs.pth", + transforms=partial(ImageClassification, crop_size=256), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilevit", + "_metrics": { + # TODO: Update with the correct values. For now, these are the expected ones from the paper. + "ImageNet-1K": { + "acc@1": 74.8, + "acc@5": 92.3, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class MobileViT_XXS_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + # TODO: Update the URL once the model has been trained... + url="https://download.pytorch.org/models/mobilevit_xxs.pth", + transforms=partial(ImageClassification, crop_size=256), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilevit", + "_metrics": { + # TODO: Update with the correct values. For now, these are the expected ones from the paper. 
+ "ImageNet-1K": { + "acc@1": 69.0, + "acc@5": 88.9, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + @register_model() def mobile_vit_s(*, weights: Optional[MobileViT_Weights] = None, progress: bool = True, **kwargs: Any): """ @@ -427,17 +410,6 @@ def mobile_vit_xxs(*, weights: Optional[MobileViT_Weights] = None, progress: boo return _mobile_vit(c=xxs_c, d=xxs_d, weights=weights, progress=progress, expand_ratio=2, **kwargs) -# TODO: Implement this... -def _mobile_vit_v2(): - pass - - -@register_model() -def mobile_vit_v2(): - # TODO: Finish and add documentation. - weights = MobileViT_V2_Weights.verify(weights) - return _mobile_vit_v2(weights=weights) - if __name__ == "__main__": block = MobileViTBlock(dim=8 * 10, depth=1, channel=3, kernel_size=3, patch_size=(2, 2), mlp_dim=2, dropout=0.5) diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index a0a42ab07b7..9e6f68cfb3d 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -113,7 +113,6 @@ def forward(self, input: torch.Tensor): x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False) x = self.dropout(x) x = x + input - y = self.ln_2(x) y = self.mlp(y) return x + y From e58898206be1a9104b8c1d501ecf73f75a492924 Mon Sep 17 00:00:00 2001 From: Yassine Alouini Date: Sun, 27 Nov 2022 14:55:15 +0100 Subject: [PATCH 11/15] More changes thanks to code review. --- torchvision/models/mobilevit.py | 7 +------ torchvision/ops/misc.py | 1 + 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/torchvision/models/mobilevit.py b/torchvision/models/mobilevit.py index dabfc3407d3..8b5aaa5608c 100644 --- a/torchvision/models/mobilevit.py +++ b/torchvision/models/mobilevit.py @@ -22,8 +22,6 @@ } - - # The EncoderBlock and Encoder from vision_transformer.py # TODO: Maybe refactor later... 
class TransformerEncoderBlock(nn.Module): @@ -262,9 +260,7 @@ def __init__(self, num_classes: int = 1000, d: List[int] = None, c: List[int] = self.avgpool = nn.AvgPool2d(8, 1) # 1 x 1 convolution as an output layer (before fc) self.conv_last = nn.Sequential(nn.Conv2d(c[9], c[10], 1, 1, 0, bias=False), nn.BatchNorm2d(c[10]), nn.SiLU()) - self.classifier = nn.Sequential( - nn.Flatten(1), nn.Linear(c[10], self.num_classes) - ) + self.classifier = nn.Sequential(nn.Flatten(1), nn.Linear(c[10], self.num_classes)) def forward(self, x): x = self.conv_first(x) @@ -410,7 +406,6 @@ def mobile_vit_xxs(*, weights: Optional[MobileViT_Weights] = None, progress: boo return _mobile_vit(c=xxs_c, d=xxs_d, weights=weights, progress=progress, expand_ratio=2, **kwargs) - if __name__ == "__main__": block = MobileViTBlock(dim=8 * 10, depth=1, channel=3, kernel_size=3, patch_size=(2, 2), mlp_dim=2, dropout=0.5) # B x C x H x W diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index a62c00a4765..80b899c935b 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -286,6 +286,7 @@ def __init__( # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal: # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py params = {} if inplace is None else {"inplace": inplace} + layers = [] in_dim = in_channels for hidden_dim in hidden_channels[:-1]: From 2a5089ee4c4fd688bc1ec6698cdd6bcc119bbd83 Mon Sep 17 00:00:00 2001 From: Yassine Alouini Date: Sun, 27 Nov 2022 14:57:07 +0100 Subject: [PATCH 12/15] Formatting. --- torchvision/ops/misc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 80b899c935b..a62c00a4765 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -286,7 +286,6 @@ def __init__( # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal: # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py params = {} if inplace is None else {"inplace": inplace} - layers = [] in_dim = in_channels for hidden_dim in hidden_channels[:-1]: From 1424f9a4f0a05841e214101e7cb771e2a4d1a90f Mon Sep 17 00:00:00 2001 From: Yassine Alouini Date: Sun, 27 Nov 2022 14:58:07 +0100 Subject: [PATCH 13/15] Update misc.py --- torchvision/ops/misc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index a62c00a4765..cee7fee9596 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -287,6 +287,7 @@ def __init__( # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py params = {} if inplace is None else {"inplace": inplace} layers = [] + in_dim = in_channels for hidden_dim in hidden_channels[:-1]: layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias)) From c34c1e6dc5bc97271fb6c44587e5bed94cc40dcc Mon Sep 17 00:00:00 2001 From: Yassine Alouini Date: Sun, 27 Nov 2022 15:05:03 +0100 Subject: [PATCH 14/15] Remove useless change. 
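The three small commits above only shuffle `layers = []` inside `torchvision.ops.MLP.__init__` and end up close to the original layout; the MLP itself is a generic Linear/activation/dropout stack, presumably touched here because the MobileViT transformer encoder builds on it. A short usage sketch with illustrative sizes:

import torch
from torchvision.ops import MLP

# 80-dim tokens, one hidden layer of 160, projected back to 80.
mlp = MLP(in_channels=80, hidden_channels=[160, 80], activation_layer=torch.nn.SiLU, dropout=0.1)
tokens = torch.randn(2, 64, 80)  # B x N x C token sequence
print(mlp(tokens).shape)  # torch.Size([2, 64, 80])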
--- torchvision/models/vision_transformer.py | 1 + torchvision/ops/misc.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 9e6f68cfb3d..a0a42ab07b7 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -113,6 +113,7 @@ def forward(self, input: torch.Tensor): x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False) x = self.dropout(x) x = x + input + y = self.ln_2(x) y = self.mlp(y) return x + y diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index cee7fee9596..d4bda7decc5 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -286,8 +286,8 @@ def __init__( # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal: # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py params = {} if inplace is None else {"inplace": inplace} + layers = [] - in_dim = in_channels for hidden_dim in hidden_channels[:-1]: layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias)) From c7a54d59477249063dc64ee9b984d75e6cceed84 Mon Sep 17 00:00:00 2001 From: Yassine Alouini Date: Fri, 30 Dec 2022 08:58:51 +0100 Subject: [PATCH 15/15] Add .sh script to train on cloud. --- run.sh | 1 + 1 file changed, 1 insertion(+) create mode 100644 run.sh diff --git a/run.sh b/run.sh new file mode 100644 index 00000000000..72397557b81 --- /dev/null +++ b/run.sh @@ -0,0 +1 @@ +python setup.py develop && torchrun --nproc_per_node=8 vision-1/references/classification/train.py --model mobile_vit_xxs \ No newline at end of file
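The vision_transformer.py hunk above restores `y = self.ln_2(x)`, without which `y` would be undefined on the next line; the earlier removal was accidental. The final commit adds run.sh, which launches the standard classification reference script on 8 GPUs with torchrun (the hard-coded `vision-1/...` path points at the author's cloud checkout). For orientation, a minimal sketch of the distributed setup such a training script performs under torchrun; this is not the reference script itself:

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision.models import get_model


def main() -> None:
    # torchrun sets LOCAL_RANK, RANK and WORLD_SIZE for every worker process.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")

    # "mobile_vit_xxs" resolves through the @register_model() decorator above,
    # so this only works with this branch installed.
    model = get_model("mobile_vit_xxs", weights=None).cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])

    # ... build the ImageNet datasets, DistributedSampler, optimizer and the
    # usual training loop here, as the classification reference script does ...

    dist.destroy_process_group()


if __name__ == "__main__":
    main()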