Skip to content

Commit 02262a2

Browse files
lcskrishnaiotamudelta
authored andcommitted
enabled cuda tests (#248)
1 parent 3791971 commit 02262a2

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

test/test_cuda.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,8 @@ def tmp(t):
344344
('max', small_3d_unique, lambda t: [-1], 'neg_dim'),
345345
('max', medium_2d, lambda t: [medium_2d(t)], 'elementwise'),
346346
('min', small_3d_unique, lambda t: []),
347-
('min', small_3d_unique, lambda t: [1], 'dim', types, False, skipIfRocm),
348-
('min', small_3d_unique, lambda t: [-1], 'neg_dim', types, False, skipIfRocm),
347+
('min', small_3d_unique, lambda t: [1], 'dim'),
348+
('min', small_3d_unique, lambda t: [-1], 'neg_dim'),
349349
('min', medium_2d, lambda t: [medium_2d(t)], 'elementwise'),
350350
('mean', small_3d, lambda t: []),
351351
('mean', small_3d, lambda t: [-1], 'neg_dim'),
@@ -393,11 +393,11 @@ def tmp(t):
393393
('size', new_t(1, 2, 3, 4), lambda t: [],),
394394
('size', new_t(1, 2, 3, 4), lambda t: [1], 'dim'),
395395
('size', new_t(1, 2, 3, 4), lambda t: [-2], 'neg_dim'),
396-
('sort', small_3d_unique, lambda t: [], '', types, False, skipIfRocm),
397-
('sort', small_3d_unique, lambda t: [1], 'dim', types, False, skipIfRocm),
398-
('sort', small_3d_unique, lambda t: [-1], 'neg_dim', types, False, skipIfRocm),
399-
('sort', small_3d_unique, lambda t: [1, True], 'dim_descending', types, False, skipIfRocm),
400-
('sort', small_3d_unique, lambda t: [-1, True], 'neg_dim_descending', types, False, skipIfRocm),
396+
('sort', small_3d_unique, lambda t: [], ''),
397+
('sort', small_3d_unique, lambda t: [1], 'dim'),
398+
('sort', small_3d_unique, lambda t: [-1], 'neg_dim'),
399+
('sort', small_3d_unique, lambda t: [1, True], 'dim_descending'),
400+
('sort', small_3d_unique, lambda t: [-1, True], 'neg_dim_descending'),
401401
('split', small_3d, lambda t: [2],),
402402
('split', small_3d, lambda t: [2, 1], 'dim'),
403403
('split', small_3d, lambda t: [2, -3], 'neg_dim'),
@@ -409,9 +409,9 @@ def tmp(t):
409409
('transpose', new_t(1, 2, 3, 4), lambda t: [1, 2],),
410410
('transpose', new_t(1, 2, 3, 4), lambda t: [-1, -2], 'neg_dim'),
411411
('to_list', small_3d, lambda t: [],),
412-
('topk', small_3d_unique, lambda t: [2, 1, False, True], 'dim_sort', types, False, skipIfRocm),
413-
('topk', small_3d_unique, lambda t: [2, -1, False, True], 'neg_dim_sort', types, False, skipIfRocm),
414-
('topk', small_3d_unique, lambda t: [2, 1, True, True], 'dim_desc_sort', types, False, skipIfRocm),
412+
('topk', small_3d_unique, lambda t: [2, 1, False, True], 'dim_sort', types, False, "skipIfRocm:HalfTensor"),
413+
('topk', small_3d_unique, lambda t: [2, -1, False, True], 'neg_dim_sort', types, False, "skipIfRocm:HalfTensor"),
414+
('topk', small_3d_unique, lambda t: [2, 1, True, True], 'dim_desc_sort', types, False, "skipIfRocm:HalfTensor"),
415415
('trace', medium_2d, lambda t: []),
416416
('tril', medium_2d, lambda t: [],),
417417
('tril', medium_2d_expanded, lambda t: [], 'zero_stride', types, True),
@@ -1210,11 +1210,9 @@ def test_cat(self):
12101210
z = torch.cat([x, y])
12111211
self.assertEqual(z.size(), (21, SIZE, SIZE))
12121212

1213-
@skipIfRocm
12141213
def test_cat_empty_legacy(self):
12151214
TestTorch._test_cat_empty_legacy(self, use_cuda=True)
12161215

1217-
@skipIfRocm
12181216
def test_cat_empty(self):
12191217
TestTorch._test_cat_empty(self, use_cuda=True)
12201218

@@ -1708,7 +1706,6 @@ def test_btrisolve(self):
17081706
def test_dim_reduction(self):
17091707
TestTorch._test_dim_reduction(self, lambda t: t.cuda())
17101708

1711-
@skipIfRocm
17121709
def test_tensor_gather(self):
17131710
TestTorch._test_gather(self, lambda t: t.cuda(), False)
17141711

0 commit comments

Comments
 (0)