Skip to content

Commit f2f833f

Browse files
authored
Merge pull request #138 from lcskrishna/cl/enable-nn-test
enable test_nn unit tests.
2 parents af8b447 + 77a6f03 commit f2f833f

File tree

4 files changed

+241
-51
lines changed

4 files changed

+241
-51
lines changed

test/common_cuda.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
import torch
44
import torch.cuda
5+
from common import TEST_WITH_ROCM
56

67

78
TEST_CUDA = torch.cuda.is_available()
89
TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
910
CUDA_DEVICE = TEST_CUDA and torch.device("cuda:0")
10-
TEST_CUDNN = TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE))
11+
TEST_CUDNN = TEST_CUDA and (TEST_WITH_ROCM or torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))
1112
TEST_CUDNN_VERSION = TEST_CUDNN and torch.backends.cudnn.version()
1213

1314

test/common_nn.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ def get_weight(m):
4040
module_name='Linear',
4141
constructor_args=(10, 8),
4242
input_size=(4, 10),
43-
reference_fn=lambda i, p: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8)
43+
reference_fn=lambda i, p: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8),
44+
test_cuda = (not TEST_WITH_ROCM)
4445
),
4546
dict(
4647
module_name='Linear',
@@ -115,6 +116,7 @@ def get_weight(m):
115116
constructor_args=(1,),
116117
input_size=(10, 20),
117118
reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1, True).expand(10, 20)).log_(),
119+
test_cuda = (not TEST_WITH_ROCM)
118120
),
119121
dict(
120122
module_name='LogSoftmax',
@@ -128,7 +130,8 @@ def get_weight(m):
128130
module_name='ELU',
129131
constructor_args=(2.,),
130132
input_size=(3, 2, 5),
131-
reference_fn=lambda x, _: torch.where(x >= 0, x, 2 * (x.exp() - 1))
133+
reference_fn=lambda x, _: torch.where(x >= 0, x, 2 * (x.exp() - 1)),
134+
test_cuda=(not TEST_WITH_ROCM),
132135
),
133136
# TODO: reference function
134137
dict(
@@ -254,7 +257,8 @@ def get_weight(m):
254257
),
255258
dict(
256259
module_name='Tanhshrink',
257-
input_size=(2, 3, 4, 5)
260+
input_size=(2, 3, 4, 5),
261+
test_cuda = (not TEST_WITH_ROCM)
258262
),
259263
]
260264

test/run_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
'distributions',
4848
'multiprocessing',
4949
'nccl',
50-
'nn',
5150
'utils',
5251
]
5352

0 commit comments

Comments
 (0)