Merged
Adding MViT architecture #6105
Conversation
Here is the approach I used to verify that the implementation works exactly as expected.

import collections

import torch
from pytorchvideo.models.vision_transformers import create_multiscale_vision_transformers

from torchvision.models.video import mvitv2 as TorchVision


class PyTorchVideo:
    @staticmethod
    def mvit_v2_t(**kwargs):
        return create_multiscale_vision_transformers(
            spatial_size=(224, 224),
            temporal_size=16,
            depth=10,
            embed_dim_mul=[[1, 2.0], [3, 2.0], [8, 2.0]],
            atten_head_mul=[[1, 2.0], [3, 2.0], [8, 2.0]],
            pool_q_stride_size=[[0, 1, 1, 1], [1, 1, 2, 2], [2, 1, 1, 1], [3, 1, 2, 2], [4, 1, 1, 1], [5, 1, 1, 1],
                                [6, 1, 1, 1], [7, 1, 1, 1], [8, 1, 2, 2], [9, 1, 1, 1]],
            droppath_rate_block=0.1,
            # additional params for PyTorch Video
            residual_pool=True,
            separate_qkv=False,
            pool_kv_stride_adaptive=(1, 8, 8),
            pool_kvq_kernel=(3, 3, 3),
            head_dropout_rate=0.0,
            **kwargs,
        )

    @staticmethod
    def mvit_v2_s(**kwargs):
        return create_multiscale_vision_transformers(
            spatial_size=(224, 224),
            temporal_size=16,
            depth=16,
            embed_dim_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
            atten_head_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
            pool_q_stride_size=[[0, 1, 1, 1], [1, 1, 2, 2], [2, 1, 1, 1], [3, 1, 2, 2], [4, 1, 1, 1], [5, 1, 1, 1],
                                [6, 1, 1, 1], [7, 1, 1, 1], [8, 1, 1, 1], [9, 1, 1, 1], [10, 1, 1, 1], [11, 1, 1, 1],
                                [12, 1, 1, 1], [13, 1, 1, 1], [14, 1, 2, 2], [15, 1, 1, 1]],
            droppath_rate_block=0.1,
            # additional params for PyTorch Video
            residual_pool=True,
            separate_qkv=False,
            pool_kv_stride_adaptive=(1, 8, 8),
            pool_kvq_kernel=(3, 3, 3),
            head_dropout_rate=0.0,
            **kwargs,
        )

    @staticmethod
    def mvit_v2_b(**kwargs):
        return create_multiscale_vision_transformers(
            spatial_size=(224, 224),
            temporal_size=32,
            depth=24,
            embed_dim_mul=[[2, 2.0], [5, 2.0], [21, 2.0]],
            atten_head_mul=[[2, 2.0], [5, 2.0], [21, 2.0]],
            pool_q_stride_size=[[0, 1, 1, 1], [1, 1, 1, 1], [2, 1, 2, 2], [3, 1, 1, 1], [4, 1, 1, 1], [5, 1, 2, 2],
                                [6, 1, 1, 1], [7, 1, 1, 1], [8, 1, 1, 1], [9, 1, 1, 1], [10, 1, 1, 1], [11, 1, 1, 1],
                                [12, 1, 1, 1], [13, 1, 1, 1], [14, 1, 1, 1], [15, 1, 1, 1], [16, 1, 1, 1], [17, 1, 1, 1],
                                [18, 1, 1, 1], [19, 1, 1, 1], [20, 1, 1, 1], [21, 1, 2, 2], [22, 1, 1, 1], [23, 1, 1, 1]],
            droppath_rate_block=0.3,
            # additional params for PyTorch Video
            residual_pool=True,
            separate_qkv=False,
            pool_kv_stride_adaptive=(1, 8, 8),
            pool_kvq_kernel=(3, 3, 3),
            head_dropout_rate=0.0,
            **kwargs,
        )


def ptv_to_tv_weights(state_dict):
    d = dict(state_dict)
    # remapping keys
    mapping = collections.OrderedDict([
        ("patch_embed.patch_model.weight", "conv_proj.weight"),
        ("patch_embed.patch_model.bias", "conv_proj.bias"),
        ("cls_positional_encoding.cls_token", "pos_encoding.class_token"),
        ("cls_positional_encoding.pos_embed_spatial", "pos_encoding.spatial_pos"),
        ("cls_positional_encoding.pos_embed_temporal", "pos_encoding.temporal_pos"),
        ("cls_positional_encoding.pos_embed_class", "pos_encoding.class_pos"),
        ("attn.proj.weight", "attn.project.0.weight"),
        ("attn.proj.bias", "attn.project.0.bias"),
        ("attn.pool_q.weight", "attn.pool_q.pool.weight"),
        ("attn.norm_q.weight", "attn.pool_q.norm_act.0.weight"),
        ("attn.norm_q.bias", "attn.pool_q.norm_act.0.bias"),
        ("attn.pool_k.weight", "attn.pool_k.pool.weight"),
        ("attn.norm_k.weight", "attn.pool_k.norm_act.0.weight"),
        ("attn.norm_k.bias", "attn.pool_k.norm_act.0.bias"),
        ("attn.pool_v.weight", "attn.pool_v.pool.weight"),
        ("attn.norm_v.weight", "attn.pool_v.norm_act.0.weight"),
        ("attn.norm_v.bias", "attn.pool_v.norm_act.0.bias"),
        ("mlp.fc1.weight", "mlp.0.weight"),
        ("mlp.fc1.bias", "mlp.0.bias"),
        ("mlp.fc2.weight", "mlp.3.weight"),
        ("mlp.fc2.bias", "mlp.3.bias"),
        ("norm_embed.weight", "norm.weight"),
        ("norm_embed.bias", "norm.bias"),
        ("head.proj.weight", "head.0.weight"),
        ("head.proj.bias", "head.0.bias"),
        ("proj.weight", "project.weight"),
        ("proj.bias", "project.bias"),
    ])
    for old_key in list(d.keys()):
        for pattern, replacement in mapping.items():
            if pattern in old_key:
                new_key = old_key.replace(pattern, replacement)
                d[new_key] = d.pop(old_key)
                break
    # matching dimensions
    d["pos_encoding.class_token"] = d["pos_encoding.class_token"][0, 0, :]
    d["pos_encoding.spatial_pos"] = d["pos_encoding.spatial_pos"][0, :]
    d["pos_encoding.temporal_pos"] = d["pos_encoding.temporal_pos"][0, :]
    d["pos_encoding.class_pos"] = d["pos_encoding.class_pos"][0, 0, :]
    # removing unnecessary keys
    for old_key in list(d.keys()):
        if "attn._attention_pool_" in old_key:
            del d[old_key]
    return d


def compare_models(ptv_model_fn, tv_model_fn, input_shape):
    print(tv_model_fn.__name__)
    x = torch.randn(input_shape)
    ptv_m = ptv_model_fn().eval()
    exp_result = ptv_m(x).sum()
    d = ptv_m.state_dict()
    d = ptv_to_tv_weights(d)
    tv_m = tv_model_fn().eval()
    tv_m.load_state_dict(d)
    result = tv_m(x).sum()
    torch.testing.assert_close(result, exp_result, rtol=0, atol=1e-6)
    print("OK")


compare_models(PyTorchVideo.mvit_v2_t, TorchVision.mvit_v2_t, (1, 3, 16, 224, 224))
compare_models(PyTorchVideo.mvit_v2_s, TorchVision.mvit_v2_s, (1, 3, 16, 224, 224))
compare_models(PyTorchVideo.mvit_v2_b, TorchVision.mvit_v2_b, (1, 3, 32, 224, 224))

Prints:
We used PyTorchVideo's latest main at the time of writing (git hash
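Not part of the original script: if one wants to reuse the converted parameters without re-running the remap, the converted state dict can be persisted with torch.save and reloaded later. A minimal sketch, assuming the PyTorchVideo/TorchVision builders and ptv_to_tv_weights defined above are in scope; the file name is hypothetical:

# Hypothetical follow-up, not from the original comment: persist the converted
# weights so the TorchVision model can be restored without redoing the remap.
ptv_m = PyTorchVideo.mvit_v2_s().eval()
converted = ptv_to_tv_weights(ptv_m.state_dict())
torch.save(converted, "mvit_v2_s_converted.pth")  # example file name

# Later, rebuild the TorchVision model and load the converted weights.
tv_m = TorchVision.mvit_v2_s().eval()
tv_m.load_state_dict(torch.load("mvit_v2_s_converted.pth"))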
Here is the approach for confirming that this implementation doesn't introduce any performance regressions:

# Code from earlier comment goes here
import time


def benchmark(model_fn, input_shape, device, n=5, warmup=0.1):
    torch.manual_seed(42)
    m = model_fn().to(device).eval()
    x = torch.randn(input_shape).to(device)
    s = []
    for i in range(n):
        start = time.time()
        m(x)
        t = time.time() - start
        if i > n * warmup:
            s.append(t)
    print(model_fn.__name__, torch.tensor(s).median())


device = "cuda"
batch_size = 4
n = 100
print(f"device={device}, batch_size={batch_size}, n={n}")
for name, backend in [("TorchVision", TorchVision), ("PyTorchVideo", PyTorchVideo)]:
    print(name)
    benchmark(backend.mvit_v2_t, (batch_size, 3, 16, 224, 224), device, n=n)
    benchmark(backend.mvit_v2_s, (batch_size, 3, 16, 224, 224), device, n=n)
    benchmark(backend.mvit_v2_b, (batch_size, 3, 32, 224, 224), device, n=n)

Prints on an A100 (less is better):
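One caveat, not raised in the original thread: CUDA kernel launches are asynchronous, so timing m(x) with time.time() alone can under-count GPU work. A hedged variant of the timing loop that synchronizes the device and disables autograd could look like the sketch below, reusing the same builders and parameters as above:

# Sketch only: a synchronized variant of the benchmark loop above, assuming a
# CUDA device. It reuses the model builders defined earlier in this thread.
def benchmark_synced(model_fn, input_shape, device, n=5, warmup=0.1):
    torch.manual_seed(42)
    m = model_fn().to(device).eval()
    x = torch.randn(input_shape).to(device)
    s = []
    with torch.inference_mode():  # no autograd bookkeeping during timing
        for i in range(n):
            torch.cuda.synchronize()  # ensure previously queued work has finished
            start = time.time()
            m(x)
            torch.cuda.synchronize()  # wait for this forward pass to complete
            t = time.time() - start
            if i > n * warmup:  # discard warm-up iterations, as in the original
                s.append(t)
    print(model_fn.__name__, torch.tensor(s).median())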
datumbox added a commit that referenced this pull request on Jun 24, 2022:
* Adding MViT v2 architecture (#6105)
* Adding mvitv2 architecture
* Fixing memory issues on tests and minor refactorings.
* Adding input validation
* Adding docs and minor refactoring
* Add `min_temporal_size` in the supported meta-data.
* Switch Tuple[int, int, int] with List[int] to support easier the 2D case
* Adding more docs and references
* Change naming conventions of classes to follow the same pattern as MobileNetV3
* Fix test breakage.
* Update todos
* Performance optimizations.
* Add support to MViT v1 (#6179)
* Switch implementation to v1 variant.
* Fix docs
* Adding back a v2 pseudovariant
* Changing the way the network are configured.
* Temporarily removing v2
* Adding weights.
* Expand _squeeze/_unsqueeze to support arbitrary dims.
* Update references script.
* Fix tests.
* Fixing frames and preprocessing.
* Fix std/mean values in transforms.
* Add permanent Dropout and update the weights.
* Update accuracies.
* Fix documentation
* Remove unnecessary expected file.
* Skip big model test
* Rewrite the configuration logic to reduce LOC.
* Fix mypy
facebook-github-bot pushed a commit that referenced this pull request on Jun 27, 2022:
Summary:
* Adding MViT v2 architecture (#6105)
* Adding mvitv2 architecture
* Fixing memory issues on tests and minor refactorings.
* Adding input validation
* Adding docs and minor refactoring
* Add `min_temporal_size` in the supported meta-data.
* Switch Tuple[int, int, int] with List[int] to support easier the 2D case
* Adding more docs and references
* Change naming conventions of classes to follow the same pattern as MobileNetV3
* Fix test breakage.
* Update todos
* Performance optimizations.
* Add support to MViT v1 (#6179)
* Switch implementation to v1 variant.
* Fix docs
* Adding back a v2 pseudovariant
* Changing the way the network are configured.
* Temporarily removing v2
* Adding weights.
* Expand _squeeze/_unsqueeze to support arbitrary dims.
* Update references script.
* Fix tests.
* Fixing frames and preprocessing.
* Fix std/mean values in transforms.
* Add permanent Dropout and update the weights.
* Update accuracies.
* Fix documentation
* Remove unnecessary expected file.
* Skip big model test
* Rewrite the configuration logic to reduce LOC.
* Fix mypy

Reviewed By: NicolasHug
Differential Revision: D37450352
fbshipit-source-id: 5c0bf1065351d8dd612012902117fd866db02899
This PR adds MViT to TorchVision. Unlike #6086, this is a rewrite of the implementation using TorchVision's idioms and building blocks.
Based on the work of @haooooooqi, @feichtenhofer and @lyttonhao on PyTorchVideo.
Note: This implementation doesn't contain all of the improvements described in the MViTv2 paper (mainly the relative positional embeddings and the position of the downsample layer). These options will be added in a follow-up PR.
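For reference, here is a minimal usage sketch of the builders exercised by the verification script above. The module path torchvision.models.video.mvitv2 and the mvit_v2_s name come from this PR's branch; the final exported API may differ once the follow-ups land.

# Sketch only: build the 16-frame MViT variant from this PR's branch and run a
# random clip through it. No pretrained weights are loaded here.
import torch
from torchvision.models.video import mvitv2

model = mvitv2.mvit_v2_s().eval()
clip = torch.randn(1, 3, 16, 224, 224)  # (batch, channels, frames, height, width)
with torch.inference_mode():
    logits = model(clip)
print(logits.shape)  # (batch, num_classes)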