@@ -1149,13 +1149,15 @@ def _create_masks(image, masks):
1149
1149
1150
1150
1151
1151
class TestStochasticDepth :
1152
+ @pytest .mark .parametrize ("seed" , range (10 ))
1152
1153
@pytest .mark .parametrize ("p" , [0.2 , 0.5 , 0.8 ])
1153
1154
@pytest .mark .parametrize ("mode" , ["batch" , "row" ])
1154
- def test_stochastic_depth (self , mode , p ):
1155
+ def test_stochastic_depth_random (self , seed , mode , p ):
1156
+ torch .manual_seed (seed )
1155
1157
stats = pytest .importorskip ("scipy.stats" )
1156
1158
batch_size = 5
1157
1159
x = torch .ones (size = (batch_size , 3 , 4 , 4 ))
1158
- layer = ops .StochasticDepth (p = p , mode = mode ). to ( device = x . device , dtype = x . dtype )
1160
+ layer = ops .StochasticDepth (p = p , mode = mode )
1159
1161
layer .__repr__ ()
1160
1162
1161
1163
trials = 250
@@ -1173,7 +1175,22 @@ def test_stochastic_depth(self, mode, p):
1173
1175
num_samples += batch_size
1174
1176
1175
1177
p_value = stats .binom_test (counts , num_samples , p = p )
1176
- assert p_value > 0.0001
1178
+ assert p_value > 0.01
1179
+
1180
+ @pytest .mark .parametrize ("seed" , range (10 ))
1181
+ @pytest .mark .parametrize ("p" , (0 , 1 ))
1182
+ @pytest .mark .parametrize ("mode" , ["batch" , "row" ])
1183
+ def test_stochastic_depth (self , seed , mode , p ):
1184
+ torch .manual_seed (seed )
1185
+ batch_size = 5
1186
+ x = torch .ones (size = (batch_size , 3 , 4 , 4 ))
1187
+ layer = ops .StochasticDepth (p = p , mode = mode )
1188
+
1189
+ out = layer (x )
1190
+ if p == 0 :
1191
+ assert out .equal (x )
1192
+ elif p == 1 :
1193
+ assert out .equal (torch .zeros_like (x ))
1177
1194
1178
1195
1179
1196
class TestUtils :
0 commit comments