Skip to content

Commit 3636c54

Browse files
authored
Merge pull request #108 from jithunnair-amd/enable_unit_tests_for_rocm_2
Enable test_torch, test_dataloader, test_indexing and test_utils …
2 parents bfc6c3a + f912d6a commit 3636c54

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

test/run_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,14 @@
4343
'c10d',
4444
'cpp_extensions',
4545
'cuda',
46-
'dataloader',
4746
'distributed',
4847
'distributions',
49-
'indexing',
5048
'jit',
5149
'legacy_nn',
5250
'multiprocessing',
5351
'nccl',
5452
'nn',
5553
'sparse',
56-
'torch',
5754
'utils',
5855
]
5956

test/test_torch.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,7 @@ def test_multidim(x, dim):
872872
expected = fn(y, 1, keepdim=False)
873873
self.assertEqual(x[:, 1], expected, '{} with out= kwarg'.format(fn_name))
874874

875+
@skipIfRocm
875876
def test_dim_reduction(self):
876877
self._test_dim_reduction(self, lambda t: t)
877878

@@ -938,6 +939,7 @@ def test_reduction_empty(self):
938939
self.assertEqual(torch.ones((2, 1, 4), device=device), xb.all(1, keepdim=True))
939940
self.assertEqual(torch.ones((), device=device), xb.all())
940941

942+
@skipIfRocm
941943
def test_pairwise_distance_empty(self):
942944
devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
943945
for device in devices:
@@ -2110,6 +2112,7 @@ def get_int64_dtype(dtype):
21102112
dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
21112113
int64_dtype, layout, device, fv + 5, False)
21122114

2115+
@skipIfRocm
21132116
def test_empty_full(self):
21142117
self._test_empty_full(self, torch.testing.get_all_dtypes(), torch.strided, torch.device('cpu'))
21152118
if torch.cuda.device_count() > 0:
@@ -2248,6 +2251,7 @@ def test_tensor_factory_cuda_type(self):
22482251
self.assertTrue(x.is_cuda)
22492252
torch.set_default_tensor_type(saved_type)
22502253

2254+
@skipIfRocm
22512255
def test_tensor_factories_empty(self):
22522256
# ensure we can create empty tensors from each factory function
22532257
shapes = [(5, 0, 1), (0,), (0, 0, 1, 0, 2, 0, 0)]
@@ -3184,6 +3188,7 @@ def check_order(a, b):
31843188
seen.add(ixx[k][j])
31853189
self.assertEqual(len(seen), size)
31863190

3191+
@skipIfRocm
31873192
def test_sort(self):
31883193
SIZE = 4
31893194
x = torch.rand(SIZE, SIZE)
@@ -3297,6 +3302,7 @@ def test_topk_noncontiguous_gpu(self):
32973302
self.assertEqual(top1, top2)
32983303
self.assertEqual(idx1, idx2)
32993304

3305+
@skipIfRocm
33003306
def test_kthvalue(self):
33013307
SIZE = 50
33023308
x = torch.rand(SIZE, SIZE, SIZE)
@@ -3341,6 +3347,7 @@ def test_kthvalue(self):
33413347
self.assertEqual(torch.kthvalue(y, 3)[0], 3, 0)
33423348
self.assertEqual(torch.kthvalue(y, 2)[0], 1, 0)
33433349

3350+
@skipIfRocm
33443351
def test_median(self):
33453352
for size in (155, 156):
33463353
x = torch.rand(size, size)
@@ -3376,6 +3383,7 @@ def test_median(self):
33763383
# input unchanged
33773384
self.assertEqual(x, x0, 0)
33783385

3386+
@skipIfRocm
33793387
def test_mode(self):
33803388
x = torch.arange(1., SIZE * SIZE + 1).clone().resize_(SIZE, SIZE)
33813389
x[:2] = 1
@@ -3539,6 +3547,7 @@ def test_narrow(self):
35393547
self.assertEqual(x.narrow(-1, -1, 1), torch.Tensor([[2], [5], [8]]))
35403548
self.assertEqual(x.narrow(-2, -1, 1), torch.Tensor([[6, 7, 8]]))
35413549

3550+
@skipIfRocm
35423551
def test_narrow_empty(self):
35433552
devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
35443553
for device in devices:
@@ -3829,6 +3838,7 @@ def _test_gesv_batched_dims(self, cast):
38293838
self.assertEqual(x.data, cast(x_exp))
38303839

38313840
@skipIfNoLapack
3841+
@skipIfRocm
38323842
def test_gesv_batched_dims(self):
38333843
self._test_gesv_batched_dims(self, lambda t: t)
38343844

@@ -4177,6 +4187,7 @@ def test_eig(self):
41774187
self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
41784188

41794189
@skipIfNoLapack
4190+
@skipIfRocm
41804191
def test_symeig(self):
41814192
xval = torch.rand(100, 3)
41824193
cov = torch.mm(xval.t(), xval)
@@ -4908,6 +4919,7 @@ def tset_potri(self):
49084919
self.assertLessEqual(inv0.dist(inv1), 1e-12)
49094920

49104921
@skipIfNoLapack
4922+
@skipIfRocm
49114923
def test_pstrf(self):
49124924
def checkPsdCholesky(a, uplo, inplace):
49134925
if inplace:
@@ -6116,6 +6128,7 @@ def test_empty_reshape(self):
61166128
# match NumPy semantics -- don't infer the size of dimension with a degree of freedom
61176129
self.assertRaises(RuntimeError, lambda: x.reshape(0, -1))
61186130

6131+
@skipIfRocm
61196132
def test_tensor_shape_empty(self):
61206133
devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
61216134
for device in devices:
@@ -6304,6 +6317,7 @@ def test_dim_function_empty(self):
63046317
c = torch.randn((0, 1, 2), device=device)
63056318
self.assertEqual(c, c.index_select(0, ind_empty))
63066319

6320+
@skipIfRocm
63076321
def test_blas_empty(self):
63086322
devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
63096323
for device in devices:
@@ -6373,6 +6387,7 @@ def fn(torchfn, *args):
63736387
A_LU, pivots = fn(torch.btrifact, (2, 0, 0))
63746388
self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape])
63756389

6390+
@skipIfRocm
63766391
def test_blas_alpha_beta_empty(self):
63776392
devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
63786393
for device in devices:
@@ -7719,7 +7734,6 @@ def test_empty_like(self):
77197734
self.assertEqual(torch.empty_like(a).type(), a.type())
77207735

77217736
@unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
7722-
@skipIfRocm
77237737
def test_pin_memory(self):
77247738
x = torch.randn(3, 5)
77257739
self.assertFalse(x.is_pinned())
@@ -7887,7 +7901,6 @@ def test_from_numpy(self):
78877901
self.assertRaises(ValueError, lambda: torch.from_numpy(x))
78887902

78897903
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
7890-
@skipIfRocm
78917904
def test_ctor_with_numpy_array(self):
78927905
dtypes = [
78937906
np.double,

0 commit comments

Comments
 (0)