
Adding MViT architecture #6105


Merged
datumbox merged 14 commits from models/mvit_rewrite into pytorch:mvitv2 on Jun 7, 2022

Conversation

datumbox (Contributor) commented on May 27, 2022

This PR adds MViT to TorchVision. Unlike #6086, this is a rewrite of the implementation using TorchVision's idioms and building blocks.

This PR:

  • Adds the implementation and main variants of MViT
  • Adds documentation
  • Verifies results and performance against PyTorchVideo's implementation

Based on the work of @haooooooqi, @feichtenhofer, and @lyttonhao on PyTorchVideo.

Note: This implementation doesn't contain all the improvements described in the MViTv2 paper (mainly the relative positional embeddings and the position of the downsampling layer). These options will be added in a follow-up PR.
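
For reference, here is a minimal usage sketch of the new builders, using the module path and builder names exercised by the verification script further down (these may still change in follow-up work):

import torch
from torchvision.models.video import mvitv2

# Build the small variant and run a random 16-frame clip through it
model = mvitv2.mvit_v2_s().eval()
clip = torch.randn(1, 3, 16, 224, 224)  # (batch, channels, frames, height, width)
with torch.no_grad():
    logits = model(clip)
print(logits.shape)  # one row of class logits per clip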

datumbox changed the title from "Adding MViT v2 architecture" to "[WIP] Adding MViT v2 architecture" on May 27, 2022
datumbox mentioned this pull request on May 27, 2022
datumbox force-pushed the models/mvit_rewrite branch from 8cdee2f to 1330d9c on May 30, 2022 at 12:06
datumbox force-pushed the models/mvit_rewrite branch from beffe88 to 683d520 on May 30, 2022 at 12:19
datumbox (Contributor, Author) commented on May 31, 2022

Here is the approach I used to verify that the implementation works exactly as expected.

import collections
import torch

from pytorchvideo.models.vision_transformers import create_multiscale_vision_transformers
from torchvision.models.video import mvitv2 as TorchVision


class PyTorchVideo:
    @staticmethod
    def mvit_v2_t(**kwargs):
        return create_multiscale_vision_transformers(
            spatial_size=(224, 224),
            temporal_size=16,
            depth=10,
            embed_dim_mul=[[1, 2.0], [3, 2.0], [8, 2.0]],
            atten_head_mul=[[1, 2.0], [3, 2.0], [8, 2.0]],
            pool_q_stride_size=[[0, 1, 1, 1], [1, 1, 2, 2], [2, 1, 1, 1], [3, 1, 2, 2], [4, 1, 1, 1], [5, 1, 1, 1],
                                [6, 1, 1, 1], [7, 1, 1, 1], [8, 1, 2, 2], [9, 1, 1, 1]],
            droppath_rate_block=0.1,
            # additional params for PyTorch Video
            residual_pool=True,
            separate_qkv=False,
            pool_kv_stride_adaptive=(1, 8, 8),
            pool_kvq_kernel=(3, 3, 3),
            head_dropout_rate=0.0,
            **kwargs,
        )

    @staticmethod
    def mvit_v2_s(**kwargs):
        return create_multiscale_vision_transformers(
            spatial_size=(224, 224),
            temporal_size=16,
            depth=16,
            embed_dim_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
            atten_head_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
            pool_q_stride_size=[[0, 1, 1, 1], [1, 1, 2, 2], [2, 1, 1, 1], [3, 1, 2, 2], [4, 1, 1, 1], [5, 1, 1, 1],
                                [6, 1, 1, 1], [7, 1, 1, 1], [8, 1, 1, 1], [9, 1, 1, 1], [10, 1, 1, 1], [11, 1, 1, 1],
                                [12, 1, 1, 1], [13, 1, 1, 1], [14, 1, 2, 2], [15, 1, 1, 1]],
            droppath_rate_block=0.1,
            # additional params for PyTorch Video
            residual_pool=True,
            separate_qkv=False,
            pool_kv_stride_adaptive=(1, 8, 8),
            pool_kvq_kernel=(3, 3, 3),
            head_dropout_rate=0.0,
            **kwargs,
        )

    @staticmethod
    def mvit_v2_b(**kwargs):
        return create_multiscale_vision_transformers(
            spatial_size=(224, 224),
            temporal_size=32,
            depth=24,
            embed_dim_mul=[[2, 2.0], [5, 2.0], [21, 2.0]],
            atten_head_mul=[[2, 2.0], [5, 2.0], [21, 2.0]],
            pool_q_stride_size=[[0, 1, 1, 1], [1, 1, 1, 1], [2, 1, 2, 2], [3, 1, 1, 1], [4, 1, 1, 1], [5, 1, 2, 2],
                                [6, 1, 1, 1], [7, 1, 1, 1], [8, 1, 1, 1], [9, 1, 1, 1], [10, 1, 1, 1], [11, 1, 1, 1],
                                [12, 1, 1, 1], [13, 1, 1, 1], [14, 1, 1, 1], [15, 1, 1, 1], [16, 1, 1, 1], [17, 1, 1, 1],
                                [18, 1, 1, 1], [19, 1, 1, 1], [20, 1, 1, 1], [21, 1, 2, 2], [22, 1, 1, 1], [23, 1, 1, 1]],
            droppath_rate_block=0.3,
            # additional params for PyTorch Video
            residual_pool=True,
            separate_qkv=False,
            pool_kv_stride_adaptive=(1, 8, 8),
            pool_kvq_kernel=(3, 3, 3),
            head_dropout_rate=0.0,
            **kwargs,
        )


def ptv_to_tv_weights(state_dict):
    d = dict(state_dict)

    # remapping keys
    mapping = collections.OrderedDict([
        ("patch_embed.patch_model.weight", "conv_proj.weight"),
        ("patch_embed.patch_model.bias", "conv_proj.bias"),
        ("cls_positional_encoding.cls_token", "pos_encoding.class_token"),
        ("cls_positional_encoding.pos_embed_spatial", "pos_encoding.spatial_pos"),
        ("cls_positional_encoding.pos_embed_temporal", "pos_encoding.temporal_pos"),
        ("cls_positional_encoding.pos_embed_class", "pos_encoding.class_pos"),
        ("attn.proj.weight", "attn.project.0.weight"),
        ("attn.proj.bias", "attn.project.0.bias"),
        ("attn.pool_q.weight", "attn.pool_q.pool.weight"),
        ("attn.norm_q.weight", "attn.pool_q.norm_act.0.weight"),
        ("attn.norm_q.bias", "attn.pool_q.norm_act.0.bias"),
        ("attn.pool_k.weight", "attn.pool_k.pool.weight"),
        ("attn.norm_k.weight", "attn.pool_k.norm_act.0.weight"),
        ("attn.norm_k.bias", "attn.pool_k.norm_act.0.bias"),
        ("attn.pool_v.weight", "attn.pool_v.pool.weight"),
        ("attn.norm_v.weight", "attn.pool_v.norm_act.0.weight"),
        ("attn.norm_v.bias", "attn.pool_v.norm_act.0.bias"),
        ("mlp.fc1.weight", "mlp.0.weight"),
        ("mlp.fc1.bias", "mlp.0.bias"),
        ("mlp.fc2.weight", "mlp.3.weight"),
        ("mlp.fc2.bias", "mlp.3.bias"),
        ("norm_embed.weight", "norm.weight"),
        ("norm_embed.bias", "norm.bias"),
        ("head.proj.weight", "head.0.weight"),
        ("head.proj.bias", "head.0.bias"),
        ("proj.weight", "project.weight"),
        ("proj.bias", "project.bias"),
    ])
    for old_key in list(d.keys()):
        for pattern, replacement in mapping.items():
            if pattern in old_key:
                new_key = old_key.replace(pattern, replacement)
                d[new_key] = d.pop(old_key)
                break


    # matching dimensions
    d["pos_encoding.class_token"] = d["pos_encoding.class_token"][0, 0, :]
    d["pos_encoding.spatial_pos"] = d["pos_encoding.spatial_pos"][0, :]
    d["pos_encoding.temporal_pos"] = d["pos_encoding.temporal_pos"][0, :]
    d["pos_encoding.class_pos"] = d["pos_encoding.class_pos"][0, 0, :]

    # removing unnecessary keys
    for old_key in list(d.keys()):
        if "attn._attention_pool_" in old_key:
            del d[old_key]
    return d


def compare_models(ptv_model_fn, tv_model_fn, input_shape):
    print(tv_model_fn.__name__)
    x = torch.randn(input_shape)

    # Reference output from the original PyTorchVideo model
    ptv_m = ptv_model_fn().eval()
    exp_result = ptv_m(x).sum()

    # Port its weights into the TorchVision implementation
    d = ptv_m.state_dict()
    d = ptv_to_tv_weights(d)

    tv_m = tv_model_fn().eval()
    tv_m.load_state_dict(d)
    result = tv_m(x).sum()

    # Both implementations must agree up to floating-point tolerance
    torch.testing.assert_close(result, exp_result, rtol=0, atol=1e-6)
    print("OK")


compare_models(PyTorchVideo.mvit_v2_t, TorchVision.mvit_v2_t, (1, 3, 16, 224, 224))
compare_models(PyTorchVideo.mvit_v2_s, TorchVision.mvit_v2_s, (1, 3, 16, 224, 224))
compare_models(PyTorchVideo.mvit_v2_b, TorchVision.mvit_v2_b, (1, 3, 32, 224, 224))

Prints:

mvit_v2_t
OK
mvit_v2_s
OK
mvit_v2_b
OK

We used PyTorchVideo's latest main at the time of writing (git hash 718d0a46a8ec3cfd12b823a2063ab66b006281b2), installed via:

pip install git+https://github.com/facebookresearch/pytorchvideo.git@main
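
For exact reproducibility, pip can presumably also pin that same commit directly (hash taken from above):

pip install git+https://github.com/facebookresearch/pytorchvideo.git@718d0a46a8ec3cfd12b823a2063ab66b006281b2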

datumbox (Contributor, Author) commented on Jun 1, 2022

Here is the approach I used to confirm that this implementation doesn't introduce any performance regressions:

# Code from earlier comment goes here

import time

def benchmark(model_fn, input_shape, device, n=5, warmup=0.1):
    torch.manual_seed(42)
    m = model_fn().to(device).eval()
    x = torch.randn(input_shape).to(device)

    s = []
    for i in range(n):
        start = time.time()
        m(x)
        t = time.time() - start
        # Skip the first ~10% of iterations as warmup
        if i > n * warmup:
            s.append(t)

    print(model_fn.__name__, torch.tensor(s).median())


device = "cuda"
batch_size = 4
n = 100

print(f"device={device}, batch_size={batch_size}, n={n}")
for name, backend in [("TorchVision", TorchVision), ("PyTorchVideo", PyTorchVideo)]:
    print(name)
    benchmark(backend.mvit_v2_t, (batch_size, 3, 16, 224, 224), device, n=n)
    benchmark(backend.mvit_v2_s, (batch_size, 3, 16, 224, 224), device, n=n)
    benchmark(backend.mvit_v2_b, (batch_size, 3, 32, 224, 224), device, n=n)

Prints on an A100 (lower is better):

device=cuda, batch_size=4, n=100
TorchVision
mvit_v2_t tensor(0.0255)
mvit_v2_s tensor(0.0330)
mvit_v2_b tensor(0.1091)
PyTorchVideo
mvit_v2_t tensor(0.0266)
mvit_v2_s tensor(0.0345)
mvit_v2_b tensor(0.1124)
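
As a side note, a stricter GPU timing would synchronize around the forward pass so that the measured interval covers the full kernel execution. A minimal sketch of such a variant, assuming device="cuda" (this is not the measurement used for the numbers above):

def benchmark_sync(model_fn, input_shape, device, n=100, warmup=10):
    m = model_fn().to(device).eval()
    x = torch.randn(input_shape, device=device)

    timings = []
    with torch.inference_mode():
        for i in range(n):
            torch.cuda.synchronize()
            start = time.time()
            m(x)
            torch.cuda.synchronize()  # wait for the forward pass to actually finish
            if i >= warmup:
                timings.append(time.time() - start)

    print(model_fn.__name__, torch.tensor(timings).median())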

datumbox changed the base branch from main to mvitv2 on Jun 7, 2022 at 12:22
datumbox changed the title from "[WIP] Adding MViT v2 architecture" to "Adding MViT v2 architecture" on Jun 7, 2022
datumbox merged commit 69095dd into pytorch:mvitv2 on Jun 7, 2022
datumbox deleted the models/mvit_rewrite branch on Jun 7, 2022 at 12:28
datumbox added a commit that referenced this pull request on Jun 24, 2022
* Adding MViT v2 architecture (#6105)

* Adding mvitv2 architecture

* Fixing memory issues on tests and minor refactorings.

* Adding input validation

* Adding docs and minor refactoring

* Add `min_temporal_size` in the supported meta-data.

* Switch Tuple[int, int, int] with List[int] to support easier the 2D case

* Adding more docs and references

* Change naming conventions of classes to follow the same pattern as MobileNetV3

* Fix test breakage.

* Update todos

* Performance optimizations.

* Add support to MViT v1 (#6179)

* Switch implementation to v1 variant.

* Fix docs

* Adding back a v2 pseudovariant

* Changing the way the network are configured.

* Temporarily removing v2

* Adding weights.

* Expand _squeeze/_unsqueeze to support arbitrary dims.

* Update references script.

* Fix tests.

* Fixing frames and preprocessing.

* Fix std/mean values in transforms.

* Add permanent Dropout and update the weights.

* Update accuracies.

* Fix documentation

* Remove unnecessary expected file.

* Skip big model test

* Rewrite the configuration logic to reduce LOC.

* Fix mypy
facebook-github-bot pushed a commit that referenced this pull request on Jun 27, 2022, with the same commit message as above.

Reviewed By: NicolasHug

Differential Revision: D37450352

fbshipit-source-id: 5c0bf1065351d8dd612012902117fd866db02899
datumbox changed the title from "Adding MViT v2 architecture" to "Adding MViT architecture" on Sep 4, 2022