Skip to content

Commit af8b447

Browse files
authored
Merge pull request #132 from lcskrishna/cl/legacy-test-enable
Enable legacy_nn unit tests.
2 parents 5c4a6f3 + c304e87 commit af8b447

File tree

3 files changed

+113
-47
lines changed

3 files changed

+113
-47
lines changed

test/common_nn.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
import torch.cuda
99
from torch.nn.functional import _Reduction
10-
from common import TestCase, to_gpu, freeze_rng_state, is_iterable
10+
from common import TestCase, to_gpu, freeze_rng_state, is_iterable, TEST_WITH_ROCM
1111
from common_cuda import TEST_CUDA
1212
from torch.autograd.gradcheck import get_numerical_jacobian, iter_tensors
1313
import torch.backends.cudnn
@@ -102,11 +102,13 @@ def get_weight(m):
102102
constructor_args=(1,),
103103
input_size=(10, 20),
104104
reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1, True).expand(10, 20)),
105+
test_cuda = (not TEST_WITH_ROCM)
105106
),
106107
dict(
107108
module_name='Softmax2d',
108109
input_size=(1, 3, 10, 20),
109110
reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1, False)),
111+
test_cuda = (not TEST_WITH_ROCM)
110112
),
111113
dict(
112114
module_name='LogSoftmax',
@@ -120,6 +122,7 @@ def get_weight(m):
120122
input_size=(1, 3, 10, 20),
121123
reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1, False)).log_(),
122124
desc='multiparam',
125+
test_cuda = (not TEST_WITH_ROCM)
123126
),
124127
dict(
125128
module_name='ELU',
@@ -198,48 +201,56 @@ def get_weight(m):
198201
input_size=(2, 3, 4),
199202
desc='1d_multiparam',
200203
reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
204+
test_cuda = (not TEST_WITH_ROCM)
201205
),
202206
dict(
203207
module_name='PReLU',
204208
input_size=(2, 3, 4, 5),
205209
desc='2d',
206210
reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
211+
#test_cuda = (not TEST_WITH_ROCM)
207212
),
208213
dict(
209214
module_name='PReLU',
210215
constructor_args=(3,),
211216
input_size=(2, 3, 4, 5),
212217
desc='2d_multiparam',
213218
reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
219+
test_cuda = (not TEST_WITH_ROCM)
214220
),
215221
dict(
216222
module_name='PReLU',
217223
input_size=(2, 3, 4, 5, 6),
218224
reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
219225
desc='3d',
226+
#test_cuda = (not TEST_WITH_ROCM)
220227
),
221228
dict(
222229
module_name='PReLU',
223230
constructor_args=(3,),
224231
input_size=(2, 3, 4, 5, 6),
225232
desc='3d_multiparam',
226233
reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
234+
test_cuda = (not TEST_WITH_ROCM)
227235
),
228236
dict(
229237
module_name='Softsign',
230238
input_size=(3, 2, 5),
231239
reference_fn=lambda i, _: i.div(1 + torch.abs(i)),
240+
test_cuda = (not TEST_WITH_ROCM)
232241
),
233242
dict(
234243
module_name='Softmin',
235244
constructor_args=(1,),
236245
input_size=(10, 20),
246+
test_cuda = (not TEST_WITH_ROCM)
237247
),
238248
dict(
239249
module_name='Softmin',
240250
constructor_args=(1,),
241251
input_size=(2, 3, 5, 10),
242252
desc='multidim',
253+
test_cuda = (not TEST_WITH_ROCM)
243254
),
244255
dict(
245256
module_name='Tanhshrink',
@@ -576,6 +587,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
576587
reference_fn=lambda i, t, m: -(t * i.log() + (1 - t) * (1 - i).log()).sum() /
577588
(i.numel() if get_reduction(m) else 1),
578589
check_gradgrad=False,
590+
test_cuda = (not TEST_WITH_ROCM)
579591
),
580592
dict(
581593
module_name='BCELoss',
@@ -586,6 +598,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
586598
(i.numel() if get_reduction(m) else 1),
587599
desc='weights',
588600
check_gradgrad=False,
601+
test_cuda = (not TEST_WITH_ROCM)
589602
),
590603
dict(
591604
module_name='CrossEntropyLoss',
@@ -606,6 +619,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
606619
reference_fn=lambda i, t, m:
607620
hingeembeddingloss_reference(i, t, reduction=get_reduction(m)),
608621
check_sum_reduction=True,
622+
test_cuda = (not TEST_WITH_ROCM)
609623
),
610624
dict(
611625
module_name='HingeEmbeddingLoss',
@@ -616,6 +630,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
616630
hingeembeddingloss_reference(i, t, margin=0.5, reduction=get_reduction(m)),
617631
desc='margin',
618632
check_sum_reduction=True,
633+
test_cuda = (not TEST_WITH_ROCM)
619634
),
620635
dict(
621636
module_name='MultiLabelMarginLoss',
@@ -642,6 +657,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
642657
target_fn=lambda: torch.rand(5, 10).mul(2).floor(),
643658
reference_fn=lambda i, t, m: -(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()).sum() / i.numel(),
644659
check_gradgrad=False,
660+
test_cuda = (not TEST_WITH_ROCM)
645661
),
646662
dict(
647663
module_name='MultiMarginLoss',
@@ -720,6 +736,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
720736
reference_fn=lambda i, t, m:
721737
cosineembeddingloss_reference(i[0], i[1], t, reduction=get_reduction(m)),
722738
check_sum_reduction=True,
739+
test_cuda = (not TEST_WITH_ROCM)
723740
),
724741
dict(
725742
module_name='CosineEmbeddingLoss',
@@ -730,6 +747,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
730747
cosineembeddingloss_reference(i[0], i[1], t, margin=0.7, reduction=get_reduction(m)),
731748
desc='margin',
732749
check_sum_reduction=True,
750+
test_cuda = (not TEST_WITH_ROCM)
733751
),
734752
dict(
735753
module_name='MarginRankingLoss',
@@ -738,6 +756,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
738756
reference_fn=lambda i, t, m:
739757
marginrankingloss_reference(i[0], i[1], t, reduction=get_reduction(m)),
740758
check_sum_reduction=True,
759+
test_cuda = (not TEST_WITH_ROCM)
741760
),
742761
dict(
743762
module_name='MarginRankingLoss',
@@ -748,6 +767,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
748767
marginrankingloss_reference(i[0], i[1], t, margin=0.5, reduction=get_reduction(m)),
749768
desc='margin',
750769
check_sum_reduction=True,
770+
test_cuda = (not TEST_WITH_ROCM)
751771
),
752772
]
753773

test/run_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
'cuda',
4646
'distributed',
4747
'distributions',
48-
'legacy_nn',
4948
'multiprocessing',
5049
'nccl',
5150
'nn',

0 commit comments

Comments
 (0)