@@ -40,7 +40,8 @@ def get_weight(m):
40
40
module_name = 'Linear' ,
41
41
constructor_args = (10 , 8 ),
42
42
input_size = (4 , 10 ),
43
- reference_fn = lambda i , p : torch .mm (i , p [0 ].t ()) + p [1 ].view (1 , - 1 ).expand (4 , 8 )
43
+ reference_fn = lambda i , p : torch .mm (i , p [0 ].t ()) + p [1 ].view (1 , - 1 ).expand (4 , 8 ),
44
+ test_cuda = (not TEST_WITH_ROCM )
44
45
),
45
46
dict (
46
47
module_name = 'Linear' ,
@@ -115,6 +116,7 @@ def get_weight(m):
115
116
constructor_args = (1 ,),
116
117
input_size = (10 , 20 ),
117
118
reference_fn = lambda i , _ : torch .exp (i ).div_ (torch .exp (i ).sum (1 , True ).expand (10 , 20 )).log_ (),
119
+ test_cuda = (not TEST_WITH_ROCM )
118
120
),
119
121
dict (
120
122
module_name = 'LogSoftmax' ,
@@ -128,7 +130,8 @@ def get_weight(m):
128
130
module_name = 'ELU' ,
129
131
constructor_args = (2. ,),
130
132
input_size = (3 , 2 , 5 ),
131
- reference_fn = lambda x , _ : torch .where (x >= 0 , x , 2 * (x .exp () - 1 ))
133
+ reference_fn = lambda x , _ : torch .where (x >= 0 , x , 2 * (x .exp () - 1 )),
134
+ test_cuda = (not TEST_WITH_ROCM ),
132
135
),
133
136
# TODO: reference function
134
137
dict (
@@ -254,7 +257,8 @@ def get_weight(m):
254
257
),
255
258
dict (
256
259
module_name = 'Tanhshrink' ,
257
- input_size = (2 , 3 , 4 , 5 )
260
+ input_size = (2 , 3 , 4 , 5 ),
261
+ test_cuda = (not TEST_WITH_ROCM )
258
262
),
259
263
]
260
264
0 commit comments