Skip to content

Commit bbfda42

Browse files
authored
Fix flakiness on StochasticDepth test (#4758)
* Fix flakiness on the TestStochasticDepth test. * Fix minor bug when p=1.0 * Remove device and dtype setting.
1 parent 5ea2348 commit bbfda42

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

test/test_ops.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,13 +1149,15 @@ def _create_masks(image, masks):
11491149

11501150

11511151
class TestStochasticDepth:
1152+
@pytest.mark.parametrize("seed", range(10))
11521153
@pytest.mark.parametrize("p", [0.2, 0.5, 0.8])
11531154
@pytest.mark.parametrize("mode", ["batch", "row"])
1154-
def test_stochastic_depth(self, mode, p):
1155+
def test_stochastic_depth_random(self, seed, mode, p):
1156+
torch.manual_seed(seed)
11551157
stats = pytest.importorskip("scipy.stats")
11561158
batch_size = 5
11571159
x = torch.ones(size=(batch_size, 3, 4, 4))
1158-
layer = ops.StochasticDepth(p=p, mode=mode).to(device=x.device, dtype=x.dtype)
1160+
layer = ops.StochasticDepth(p=p, mode=mode)
11591161
layer.__repr__()
11601162

11611163
trials = 250
@@ -1173,7 +1175,22 @@ def test_stochastic_depth(self, mode, p):
11731175
num_samples += batch_size
11741176

11751177
p_value = stats.binom_test(counts, num_samples, p=p)
1176-
assert p_value > 0.0001
1178+
assert p_value > 0.01
1179+
1180+
@pytest.mark.parametrize("seed", range(10))
1181+
@pytest.mark.parametrize("p", (0, 1))
1182+
@pytest.mark.parametrize("mode", ["batch", "row"])
1183+
def test_stochastic_depth(self, seed, mode, p):
1184+
torch.manual_seed(seed)
1185+
batch_size = 5
1186+
x = torch.ones(size=(batch_size, 3, 4, 4))
1187+
layer = ops.StochasticDepth(p=p, mode=mode)
1188+
1189+
out = layer(x)
1190+
if p == 0:
1191+
assert out.equal(x)
1192+
elif p == 1:
1193+
assert out.equal(torch.zeros_like(x))
11771194

11781195

11791196
class TestUtils:

torchvision/ops/stochastic_depth.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True)
3434
else:
3535
size = [1] * input.ndim
3636
noise = torch.empty(size, dtype=input.dtype, device=input.device)
37-
noise = noise.bernoulli_(survival_rate).div_(survival_rate)
37+
noise = noise.bernoulli_(survival_rate)
38+
if survival_rate > 0.0:
39+
noise.div_(survival_rate)
3840
return input * noise
3941

4042

0 commit comments

Comments
 (0)