
Commit fe5e449

Merge branch 'main' into TestColorJitter_seeds
2 parents d1dc69c + e00d818 commit fe5e449

File tree

5 files changed (+177, -29 lines)


test/test_ops.py

Lines changed: 6 additions & 2 deletions
@@ -46,9 +46,11 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwar
         tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
         torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)
 
+    @pytest.mark.parametrize("seed", range(10))
     @pytest.mark.parametrize("device", cpu_and_gpu())
     @pytest.mark.parametrize("contiguous", (True, False))
-    def test_backward(self, device, contiguous):
+    def test_backward(self, seed, device, contiguous):
+        torch.random.manual_seed(seed)
         pool_size = 2
         x = torch.rand(1, 2 * (pool_size ** 2), 5, 5, dtype=self.dtype, device=device, requires_grad=True)
         if not contiguous:

@@ -845,7 +847,9 @@ def test_frozenbatchnorm2d_repr(self):
         expected_string = f"FrozenBatchNorm2d({num_features}, eps={eps})"
         assert repr(t) == expected_string
 
-    def test_frozenbatchnorm2d_eps(self):
+    @pytest.mark.parametrize("seed", range(10))
+    def test_frozenbatchnorm2d_eps(self, seed):
+        torch.random.manual_seed(seed)
         sample_size = (4, 32, 28, 28)
         x = torch.rand(sample_size)
         state_dict = dict(
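Editor's note: the seed parametrization above turns one nondeterministic test into ten deterministic runs. pytest invokes the test once per seed, and torch.random.manual_seed(seed) fixes the global RNG before any torch.rand call. A minimal standalone sketch of the same pattern (the test body here is illustrative, not part of this diff):

    import pytest
    import torch


    @pytest.mark.parametrize("seed", range(10))
    def test_seeded_rand_is_reproducible(seed):
        # Seeding the global torch RNG makes every draw repeatable per seed.
        torch.random.manual_seed(seed)
        first = torch.rand(3)
        torch.random.manual_seed(seed)
        second = torch.rand(3)
        torch.testing.assert_close(first, second)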

test/test_transforms.py

Lines changed: 6 additions & 2 deletions
@@ -1750,8 +1750,10 @@ def test_color_jitter():
     color_jitter.__repr__()
 
 
+@pytest.mark.parametrize("seed", range(10))
 @pytest.mark.skipif(stats is None, reason="scipy.stats not available")
-def test_random_erasing():
+def test_random_erasing(seed):
+    torch.random.manual_seed(seed)
     img = torch.ones(3, 128, 128)
 
     t = transforms.RandomErasing(scale=(0.1, 0.1), ratio=(1 / 3, 3.0))

@@ -1852,8 +1854,10 @@ def test_randomperspective():
     )
 
 
+@pytest.mark.parametrize("seed", range(10))
 @pytest.mark.parametrize("mode", ["L", "RGB", "F"])
-def test_randomperspective_fill(mode):
+def test_randomperspective_fill(mode, seed):
+    torch.random.manual_seed(seed)
 
     # assert fill being either a Sequence or a Number
     with pytest.raises(TypeError):
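Editor's note: the transforms tests follow the same recipe. RandomErasing and RandomPerspective draw their parameters from the global torch RNG, so seeding before the call makes each of the ten runs reproducible. A small illustrative check under a fixed seed (not from this diff):

    import torch
    from torchvision import transforms

    torch.random.manual_seed(0)
    img = torch.ones(3, 128, 128)
    # p=1.0 forces the erase, so some region of the all-ones image becomes 0.
    t = transforms.RandomErasing(p=1.0, scale=(0.1, 0.1), ratio=(1 / 3, 3.0))
    assert not torch.equal(t(img), img)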

test/test_transforms_video.py

Lines changed: 7 additions & 25 deletions
@@ -160,34 +160,16 @@ def test_to_tensor_video(self):
 
         trans.__repr__()
 
-    @pytest.mark.skipif(stats is None, reason="scipy.stats not available")
-    def test_random_horizontal_flip_video(self):
-        random_state = random.getstate()
-        random.seed(42)
+    @pytest.mark.parametrize("p", (0, 1))
+    def test_random_horizontal_flip_video(self, p):
         clip = torch.rand((3, 4, 112, 112), dtype=torch.float)
         hclip = clip.flip((-1))
 
-        num_samples = 250
-        num_horizontal = 0
-        for _ in range(num_samples):
-            out = transforms.RandomHorizontalFlipVideo()(clip)
-            if torch.all(torch.eq(out, hclip)):
-                num_horizontal += 1
-
-        p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
-        random.setstate(random_state)
-        assert p_value > 0.0001
-
-        num_samples = 250
-        num_horizontal = 0
-        for _ in range(num_samples):
-            out = transforms.RandomHorizontalFlipVideo(p=0.7)(clip)
-            if torch.all(torch.eq(out, hclip)):
-                num_horizontal += 1
-
-        p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
-        random.setstate(random_state)
-        assert p_value > 0.0001
+        out = transforms.RandomHorizontalFlipVideo(p=p)(clip)
+        if p == 0:
+            torch.testing.assert_close(out, clip)
+        elif p == 1:
+            torch.testing.assert_close(out, hclip)
 
         transforms.RandomHorizontalFlipVideo().__repr__()
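Editor's note: this rewrite drops the statistical machinery entirely. Instead of sampling 250 flips and checking stats.binom_test against a significance threshold, the test exercises the two deterministic edge cases: p=0 (never flip) and p=1 (always flip). A standalone sketch of that check, assuming the internal _transforms_video module path this test file imports from:

    import torch
    from torchvision.transforms import _transforms_video as transforms

    clip = torch.rand((3, 4, 112, 112), dtype=torch.float)
    # p=0 must return the clip unchanged; p=1 must always flip the last dim.
    assert torch.equal(transforms.RandomHorizontalFlipVideo(p=0)(clip), clip)
    assert torch.equal(transforms.RandomHorizontalFlipVideo(p=1)(clip), clip.flip(-1))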

torchvision/prototype/models/segmentation/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,2 +1,3 @@
 from .fcn import *
 from .lraspp import *
+from .deeplabv3 import *
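Editor's note: the star-import re-export makes the DeepLabV3 builders defined in the next file reachable from the segmentation package itself. A hypothetical usage line, assuming the prototype package layout shown in this diff:

    from torchvision.prototype.models import segmentation

    # Resolves via the `from .deeplabv3 import *` re-export added above.
    model = segmentation.deeplabv3_resnet50()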
torchvision/prototype/models/segmentation/deeplabv3.py

Lines changed: 157 additions & 0 deletions

@@ -0,0 +1,157 @@
+import warnings
+from functools import partial
+from typing import Any, Optional
+
+from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet
+from ...transforms.presets import VocEval
+from .._api import Weights, WeightEntry
+from .._meta import _VOC_CATEGORIES
+from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
+from ..resnet import resnet50, resnet101
+from ..resnet import ResNet50Weights, ResNet101Weights
+
+
+__all__ = [
+    "DeepLabV3",
+    "DeepLabV3ResNet50Weights",
+    "DeepLabV3ResNet101Weights",
+    "DeepLabV3MobileNetV3LargeWeights",
+    "deeplabv3_mobilenet_v3_large",
+    "deeplabv3_resnet50",
+    "deeplabv3_resnet101",
+]
+
+
+class DeepLabV3ResNet50Weights(Weights):
+    CocoWithVocLabels_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
+        transforms=partial(VocEval, resize_size=520),
+        meta={
+            "categories": _VOC_CATEGORIES,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50",
+            "mIoU": 66.4,
+            "acc": 92.4,
+        },
+    )
+
+
+class DeepLabV3ResNet101Weights(Weights):
+    CocoWithVocLabels_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
+        transforms=partial(VocEval, resize_size=520),
+        meta={
+            "categories": _VOC_CATEGORIES,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101",
+            "mIoU": 67.4,
+            "acc": 92.4,
+        },
+    )
+
+
+class DeepLabV3MobileNetV3LargeWeights(Weights):
+    CocoWithVocLabels_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
+        transforms=partial(VocEval, resize_size=520),
+        meta={
+            "categories": _VOC_CATEGORIES,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large",
+            "mIoU": 60.3,
+            "acc": 91.2,
+        },
+    )
+
+
+def deeplabv3_resnet50(
+    weights: Optional[DeepLabV3ResNet50Weights] = None,
+    weights_backbone: Optional[ResNet50Weights] = None,
+    progress: bool = True,
+    num_classes: int = 21,
+    aux_loss: Optional[bool] = None,
+    **kwargs: Any,
+) -> DeepLabV3:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        weights = DeepLabV3ResNet50Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
+
+    weights = DeepLabV3ResNet50Weights.verify(weights)
+    if "pretrained_backbone" in kwargs:
+        warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
+        weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
+    weights_backbone = ResNet50Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        aux_loss = True
+        num_classes = len(weights.meta["categories"])
+
+    backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
+    model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
+
+    if weights is not None:
+        model.load_state_dict(weights.state_dict(progress=progress))
+
+    return model
+
+
+def deeplabv3_resnet101(
+    weights: Optional[DeepLabV3ResNet101Weights] = None,
+    weights_backbone: Optional[ResNet101Weights] = None,
+    progress: bool = True,
+    num_classes: int = 21,
+    aux_loss: Optional[bool] = None,
+    **kwargs: Any,
+) -> DeepLabV3:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        weights = DeepLabV3ResNet101Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
+
+    weights = DeepLabV3ResNet101Weights.verify(weights)
+    if "pretrained_backbone" in kwargs:
+        warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
+        weights_backbone = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
+    weights_backbone = ResNet101Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        aux_loss = True
+        num_classes = len(weights.meta["categories"])
+
+    backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
+    model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
+
+    if weights is not None:
+        model.load_state_dict(weights.state_dict(progress=progress))
+
+    return model
+
+
+def deeplabv3_mobilenet_v3_large(
+    weights: Optional[DeepLabV3MobileNetV3LargeWeights] = None,
+    weights_backbone: Optional[MobileNetV3LargeWeights] = None,
+    progress: bool = True,
+    num_classes: int = 21,
+    aux_loss: Optional[bool] = None,
+    **kwargs: Any,
+) -> DeepLabV3:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        weights = DeepLabV3MobileNetV3LargeWeights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
+
+    weights = DeepLabV3MobileNetV3LargeWeights.verify(weights)
+    if "pretrained_backbone" in kwargs:
+        warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
+        weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
+    weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        aux_loss = True
+        num_classes = len(weights.meta["categories"])
+
+    backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
+    model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
+
+    if weights is not None:
+        model.load_state_dict(weights.state_dict(progress=progress))
+
+    return model
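Editor's note: the new prototype builders replace the boolean pretrained flag with an explicit weights enum, which also carries the preprocessing transform and metadata. A hedged usage sketch, assuming the prototype namespace as laid out in this diff:

    import torch
    from torchvision.prototype.models.segmentation import (
        DeepLabV3ResNet50Weights,
        deeplabv3_resnet50,
    )

    # Selecting a WeightEntry loads the checkpoint and fixes num_classes (21 VOC categories).
    weights = DeepLabV3ResNet50Weights.CocoWithVocLabels_RefV1
    model = deeplabv3_resnet50(weights=weights).eval()

    with torch.no_grad():
        out = model(torch.rand(1, 3, 520, 520))["out"]  # (1, 21, H, W) segmentation logits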
