
Commit 39b2441

Merge branch 'main' into rendered-sst2-dataset

2 parents e6c95ad + f670152

10 files changed: +82 -26 lines

docs/source/models.rst

Lines changed: 2 additions & 0 deletions

@@ -88,6 +88,7 @@ You can construct a model with random weights by calling its constructor:
     vit_b_32 = models.vit_b_32()
     vit_l_16 = models.vit_l_16()
     vit_l_32 = models.vit_l_32()
+    vit_h_14 = models.vit_h_14()

 We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
 These can be constructed by passing ``pretrained=True``:

@@ -460,6 +461,7 @@ VisionTransformer
     vit_b_32
     vit_l_16
     vit_l_32
+    vit_h_14

 Quantized Models
 ----------------
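As the updated docs note, the new variant is constructed exactly like the other ViT models. A minimal sketch of the documented usage (assuming a torchvision build that includes this commit):

    from torchvision import models

    # Randomly initialized weights; nothing is downloaded because
    # pretrained defaults to False and no checkpoint exists for this model.
    vit_h_14 = models.vit_h_14()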

hubconf.py

Lines changed: 1 addition & 0 deletions

@@ -63,4 +63,5 @@
     vit_b_32,
     vit_l_16,
     vit_l_32,
+    vit_h_14,
 )
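Because `vit_h_14` is now exported from `hubconf.py`, it should also be reachable through `torch.hub`. A hedged sketch; which branch or release carries this commit is an assumption, so a ref such as `pytorch/vision:main` may be needed:

    import torch

    # Loads the architecture definition via torch.hub; weights are random
    # since vit_h_14 has no pretrained checkpoint at this commit.
    model = torch.hub.load("pytorch/vision", "vit_h_14")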

test/common_utils.py

Lines changed: 3 additions & 18 deletions

@@ -4,29 +4,18 @@
 import random
 import shutil
 import tempfile
-from distutils.util import strtobool

 import numpy as np
-import pytest
 import torch
 from PIL import Image
 from torchvision import io

 import __main__  # noqa: 401


-def get_bool_env_var(name, *, exist_ok=False, default=False):
-    value = os.getenv(name)
-    if value is None:
-        return default
-    if exist_ok:
-        return True
-    return bool(strtobool(value))
-
-
-IN_CIRCLE_CI = get_bool_env_var("CIRCLECI")
-IN_RE_WORKER = get_bool_env_var("INSIDE_RE_WORKER", exist_ok=True)
-IN_FBCODE = get_bool_env_var("IN_FBCODE_TORCHVISION")
+IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == "true"
+IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
+IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
 CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
 CIRCLECI_GPU_NO_CUDA_MSG = "We're in a CircleCI GPU machine, and this test doesn't need cuda."

@@ -213,7 +202,3 @@ def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
     # scriptable function test
     s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
     torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
-
-
-def run_on_env_var(name, *, skip_reason=None, exist_ok=False, default=False):
-    return pytest.mark.skipif(not get_bool_env_var(name, exist_ok=exist_ok, default=default), reason=skip_reason)
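The three flags replace the removed `get_bool_env_var` helper with plain string checks, and their semantics differ slightly: `IN_CIRCLE_CI` matches only the exact string "true", `IN_RE_WORKER` is satisfied by the variable merely being set, and `IN_FBCODE` matches only "1". An illustrative sketch (the values below are hypothetical):

    import os

    os.environ["CIRCLECI"] = "true"            # IN_CIRCLE_CI -> True; "True" or "1" would not match
    os.environ["INSIDE_RE_WORKER"] = ""        # IN_RE_WORKER -> True; mere presence counts
    os.environ["IN_FBCODE_TORCHVISION"] = "1"  # IN_FBCODE -> True; only the exact string "1" does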
Binary file changed (939 Bytes); not shown.

test/test_prototype_models.py

Lines changed: 5 additions & 4 deletions

@@ -1,16 +1,17 @@
 import importlib
+import os

 import pytest
 import test_models as TM
 import torch
-from common_utils import cpu_and_gpu, run_on_env_var, needs_cuda
+from common_utils import cpu_and_gpu, needs_cuda
 from torchvision.prototype import models
 from torchvision.prototype.models._api import WeightsEnum, Weights
 from torchvision.prototype.models._utils import handle_legacy_interface

-run_if_test_with_prototype = run_on_env_var(
-    "PYTORCH_TEST_WITH_PROTOTYPE",
-    skip_reason="Prototype tests are disabled by default. Set PYTORCH_TEST_WITH_PROTOTYPE=1 to run them.",
+run_if_test_with_prototype = pytest.mark.skipif(
+    os.getenv("PYTORCH_TEST_WITH_PROTOTYPE") != "1",
+    reason="Prototype tests are disabled by default. Set PYTORCH_TEST_WITH_PROTOTYPE=1 to run them.",
 )
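The resulting marker is an ordinary `pytest.mark.skipif` object and is applied to tests as a decorator. A minimal sketch (the test body here is hypothetical):

    import os

    import pytest

    run_if_test_with_prototype = pytest.mark.skipif(
        os.getenv("PYTORCH_TEST_WITH_PROTOTYPE") != "1",
        reason="Prototype tests are disabled by default. Set PYTORCH_TEST_WITH_PROTOTYPE=1 to run them.",
    )

    @run_if_test_with_prototype
    def test_vit_h_14_prototype():
        ...  # runs only when PYTORCH_TEST_WITH_PROTOTYPE=1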

test/test_prototype_utils.py

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+import pytest
+from torchvision.prototype.utils._internal import sequence_to_str
+
+
+@pytest.mark.parametrize(
+    ("seq", "separate_last", "expected"),
+    [
+        ([], "", ""),
+        (["foo"], "", "'foo'"),
+        (["foo", "bar"], "", "'foo', 'bar'"),
+        (["foo", "bar"], "and ", "'foo' and 'bar'"),
+        (["foo", "bar", "baz"], "", "'foo', 'bar', 'baz'"),
+        (["foo", "bar", "baz"], "and ", "'foo', 'bar', and 'baz'"),
+    ],
+)
+def test_sequence_to_str(seq, separate_last, expected):
+    assert sequence_to_str(seq, separate_last=separate_last) == expected

torchvision/models/vision_transformer.py

Lines changed: 23 additions & 0 deletions

@@ -15,6 +15,7 @@
     "vit_b_32",
     "vit_l_16",
     "vit_l_32",
+    "vit_h_14",
 ]

 model_urls = {

@@ -260,6 +261,8 @@ def _vision_transformer(
     )

     if pretrained:
+        if arch not in model_urls:
+            raise ValueError(f"No checkpoint is available for model type '{arch}'!")
         state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
         model.load_state_dict(state_dict)

@@ -354,6 +357,26 @@ def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
     )


+def vit_h_14(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
+    """
+    Constructs a vit_h_14 architecture from
+    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
+
+    NOTE: Pretrained weights are not available for this model.
+    """
+    return _vision_transformer(
+        arch="vit_h_14",
+        patch_size=14,
+        num_layers=32,
+        num_heads=16,
+        hidden_dim=1280,
+        mlp_dim=5120,
+        pretrained=pretrained,
+        progress=progress,
+        **kwargs,
+    )
+
+
 def interpolate_embeddings(
     image_size: int,
     patch_size: int,
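Since `vit_h_14` has no entry in `model_urls`, the new guard in `_vision_transformer` turns `pretrained=True` into an explicit `ValueError` rather than a bare `KeyError`. A sketch of both paths:

    from torchvision import models

    model = models.vit_h_14()  # fine: randomly initialized weights

    try:
        models.vit_h_14(pretrained=True)
    except ValueError as e:
        print(e)  # No checkpoint is available for model type 'vit_h_14'!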

torchvision/prototype/models/_api.py

Lines changed: 2 additions & 3 deletions

@@ -60,9 +60,8 @@ def verify(cls, obj: Any) -> Any:

     @classmethod
     def from_str(cls, value: str) -> "WeightsEnum":
-        for k, v in cls.__members__.items():
-            if k == value:
-                return v
+        if value in cls.__members__:
+            return cls.__members__[value]
         raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")

     def get_state_dict(self, progress: bool) -> OrderedDict:
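The rewritten `from_str` is a direct, case-sensitive lookup in `cls.__members__` with unchanged behavior. For example, using a member name that appears elsewhere in this diff (assuming the prototype namespace re-exports the weights enums, as the `__all__` below suggests):

    from torchvision.prototype.models import ViT_L_32_Weights

    # Returns the enum member; an unknown name still raises ValueError.
    weights = ViT_L_32_Weights.from_str("ImageNet1K_V1")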

torchvision/prototype/models/vision_transformer.py

Lines changed: 23 additions & 0 deletions

@@ -19,10 +19,12 @@
     "ViT_B_32_Weights",
     "ViT_L_16_Weights",
     "ViT_L_32_Weights",
+    "ViT_H_14_Weights",
     "vit_b_16",
     "vit_b_32",
     "vit_l_16",
     "vit_l_32",
+    "vit_h_14",
 ]


@@ -99,6 +101,11 @@ class ViT_L_32_Weights(WeightsEnum):
     default = ImageNet1K_V1


+class ViT_H_14_Weights(WeightsEnum):
+    # Weights are not available yet.
+    pass
+
+
 def _vision_transformer(
     patch_size: int,
     num_layers: int,

@@ -192,3 +199,19 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru
         progress=progress,
         **kwargs,
     )
+
+
+@handle_legacy_interface(weights=("pretrained", None))
+def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
+    weights = ViT_H_14_Weights.verify(weights)
+
+    return _vision_transformer(
+        patch_size=14,
+        num_layers=32,
+        num_heads=16,
+        hidden_dim=1280,
+        mlp_dim=5120,
+        weights=weights,
+        progress=progress,
+        **kwargs,
+    )
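With `ViT_H_14_Weights` still empty, `weights=None` is the only valid argument for the prototype builder, and the `@handle_legacy_interface` decorator keeps legacy `pretrained`-style calls working. A minimal sketch:

    from torchvision.prototype import models

    # weights=None until ViT_H_14_Weights gains members.
    model = models.vit_h_14(weights=None)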

torchvision/prototype/utils/_internal.py

Lines changed: 6 additions & 1 deletion

@@ -30,10 +30,15 @@ class StrEnum(enum.Enum, metaclass=StrEnumMeta):


 def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
+    if not seq:
+        return ""
     if len(seq) == 1:
         return f"'{seq[0]}'"

-    return f"""'{"', '".join([str(item) for item in seq[:-1]])}', {separate_last}'{seq[-1]}'"""
+    head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
+    tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
+
+    return head + tail


 def add_suggestion(
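The rewritten `sequence_to_str` now handles the empty sequence and drops the serial comma for two-item lists when a `separate_last` connective is given; the outputs below match the parametrization in the new test:

    from torchvision.prototype.utils._internal import sequence_to_str

    sequence_to_str([])                                           # ''
    sequence_to_str(["foo", "bar"], separate_last="and ")          # "'foo' and 'bar'"
    sequence_to_str(["foo", "bar", "baz"], separate_last="and ")   # "'foo', 'bar', and 'baz'"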
