7
7
import torch
8
8
import torch .cuda
9
9
from torch .nn .functional import _Reduction
10
- from common import TestCase , to_gpu , freeze_rng_state , is_iterable
10
+ from common import TestCase , to_gpu , freeze_rng_state , is_iterable , TEST_WITH_ROCM
11
11
from common_cuda import TEST_CUDA
12
12
from torch .autograd .gradcheck import get_numerical_jacobian , iter_tensors
13
13
import torch .backends .cudnn
@@ -102,11 +102,13 @@ def get_weight(m):
102
102
constructor_args = (1 ,),
103
103
input_size = (10 , 20 ),
104
104
reference_fn = lambda i , _ : torch .exp (i ).div (torch .exp (i ).sum (1 , True ).expand (10 , 20 )),
105
+ test_cuda = (not TEST_WITH_ROCM )
105
106
),
106
107
dict (
107
108
module_name = 'Softmax2d' ,
108
109
input_size = (1 , 3 , 10 , 20 ),
109
110
reference_fn = lambda i , _ : torch .exp (i ).div (torch .exp (i ).sum (1 , False )),
111
+ test_cuda = (not TEST_WITH_ROCM )
110
112
),
111
113
dict (
112
114
module_name = 'LogSoftmax' ,
@@ -120,6 +122,7 @@ def get_weight(m):
120
122
input_size = (1 , 3 , 10 , 20 ),
121
123
reference_fn = lambda i , _ : torch .exp (i ).div_ (torch .exp (i ).sum (1 , False )).log_ (),
122
124
desc = 'multiparam' ,
125
+ test_cuda = (not TEST_WITH_ROCM )
123
126
),
124
127
dict (
125
128
module_name = 'ELU' ,
@@ -198,48 +201,56 @@ def get_weight(m):
198
201
input_size = (2 , 3 , 4 ),
199
202
desc = '1d_multiparam' ,
200
203
reference_fn = lambda i , p : torch .clamp (i , min = 0 ) + torch .clamp (i , max = 0 ) * p [0 ][0 ],
204
+ test_cuda = (not TEST_WITH_ROCM )
201
205
),
202
206
dict (
203
207
module_name = 'PReLU' ,
204
208
input_size = (2 , 3 , 4 , 5 ),
205
209
desc = '2d' ,
206
210
reference_fn = lambda i , p : torch .clamp (i , min = 0 ) + torch .clamp (i , max = 0 ) * p [0 ][0 ],
211
+ #test_cuda = (not TEST_WITH_ROCM)
207
212
),
208
213
dict (
209
214
module_name = 'PReLU' ,
210
215
constructor_args = (3 ,),
211
216
input_size = (2 , 3 , 4 , 5 ),
212
217
desc = '2d_multiparam' ,
213
218
reference_fn = lambda i , p : torch .clamp (i , min = 0 ) + torch .clamp (i , max = 0 ) * p [0 ][0 ],
219
+ test_cuda = (not TEST_WITH_ROCM )
214
220
),
215
221
dict (
216
222
module_name = 'PReLU' ,
217
223
input_size = (2 , 3 , 4 , 5 , 6 ),
218
224
reference_fn = lambda i , p : torch .clamp (i , min = 0 ) + torch .clamp (i , max = 0 ) * p [0 ][0 ],
219
225
desc = '3d' ,
226
+ #test_cuda = (not TEST_WITH_ROCM)
220
227
),
221
228
dict (
222
229
module_name = 'PReLU' ,
223
230
constructor_args = (3 ,),
224
231
input_size = (2 , 3 , 4 , 5 , 6 ),
225
232
desc = '3d_multiparam' ,
226
233
reference_fn = lambda i , p : torch .clamp (i , min = 0 ) + torch .clamp (i , max = 0 ) * p [0 ][0 ],
234
+ test_cuda = (not TEST_WITH_ROCM )
227
235
),
228
236
dict (
229
237
module_name = 'Softsign' ,
230
238
input_size = (3 , 2 , 5 ),
231
239
reference_fn = lambda i , _ : i .div (1 + torch .abs (i )),
240
+ test_cuda = (not TEST_WITH_ROCM )
232
241
),
233
242
dict (
234
243
module_name = 'Softmin' ,
235
244
constructor_args = (1 ,),
236
245
input_size = (10 , 20 ),
246
+ test_cuda = (not TEST_WITH_ROCM )
237
247
),
238
248
dict (
239
249
module_name = 'Softmin' ,
240
250
constructor_args = (1 ,),
241
251
input_size = (2 , 3 , 5 , 10 ),
242
252
desc = 'multidim' ,
253
+ test_cuda = (not TEST_WITH_ROCM )
243
254
),
244
255
dict (
245
256
module_name = 'Tanhshrink' ,
@@ -576,6 +587,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
576
587
reference_fn = lambda i , t , m : - (t * i .log () + (1 - t ) * (1 - i ).log ()).sum () /
577
588
(i .numel () if get_reduction (m ) else 1 ),
578
589
check_gradgrad = False ,
590
+ test_cuda = (not TEST_WITH_ROCM )
579
591
),
580
592
dict (
581
593
module_name = 'BCELoss' ,
@@ -586,6 +598,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
586
598
(i .numel () if get_reduction (m ) else 1 ),
587
599
desc = 'weights' ,
588
600
check_gradgrad = False ,
601
+ test_cuda = (not TEST_WITH_ROCM )
589
602
),
590
603
dict (
591
604
module_name = 'CrossEntropyLoss' ,
@@ -606,6 +619,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
606
619
reference_fn = lambda i , t , m :
607
620
hingeembeddingloss_reference (i , t , reduction = get_reduction (m )),
608
621
check_sum_reduction = True ,
622
+ test_cuda = (not TEST_WITH_ROCM )
609
623
),
610
624
dict (
611
625
module_name = 'HingeEmbeddingLoss' ,
@@ -616,6 +630,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
616
630
hingeembeddingloss_reference (i , t , margin = 0.5 , reduction = get_reduction (m )),
617
631
desc = 'margin' ,
618
632
check_sum_reduction = True ,
633
+ test_cuda = (not TEST_WITH_ROCM )
619
634
),
620
635
dict (
621
636
module_name = 'MultiLabelMarginLoss' ,
@@ -642,6 +657,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
642
657
target_fn = lambda : torch .rand (5 , 10 ).mul (2 ).floor (),
643
658
reference_fn = lambda i , t , m : - (t * i .sigmoid ().log () + (1 - t ) * (- i ).sigmoid ().log ()).sum () / i .numel (),
644
659
check_gradgrad = False ,
660
+ test_cuda = (not TEST_WITH_ROCM )
645
661
),
646
662
dict (
647
663
module_name = 'MultiMarginLoss' ,
@@ -720,6 +736,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
720
736
reference_fn = lambda i , t , m :
721
737
cosineembeddingloss_reference (i [0 ], i [1 ], t , reduction = get_reduction (m )),
722
738
check_sum_reduction = True ,
739
+ test_cuda = (not TEST_WITH_ROCM )
723
740
),
724
741
dict (
725
742
module_name = 'CosineEmbeddingLoss' ,
@@ -730,6 +747,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
730
747
cosineembeddingloss_reference (i [0 ], i [1 ], t , margin = 0.7 , reduction = get_reduction (m )),
731
748
desc = 'margin' ,
732
749
check_sum_reduction = True ,
750
+ test_cuda = (not TEST_WITH_ROCM )
733
751
),
734
752
dict (
735
753
module_name = 'MarginRankingLoss' ,
@@ -738,6 +756,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
738
756
reference_fn = lambda i , t , m :
739
757
marginrankingloss_reference (i [0 ], i [1 ], t , reduction = get_reduction (m )),
740
758
check_sum_reduction = True ,
759
+ test_cuda = (not TEST_WITH_ROCM )
741
760
),
742
761
dict (
743
762
module_name = 'MarginRankingLoss' ,
@@ -748,6 +767,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
748
767
marginrankingloss_reference (i [0 ], i [1 ], t , margin = 0.5 , reduction = get_reduction (m )),
749
768
desc = 'margin' ,
750
769
check_sum_reduction = True ,
770
+ test_cuda = (not TEST_WITH_ROCM )
751
771
),
752
772
]
753
773
0 commit comments