@@ -344,8 +344,8 @@ def tmp(t):
344
344
('max' , small_3d_unique , lambda t : [- 1 ], 'neg_dim' ),
345
345
('max' , medium_2d , lambda t : [medium_2d (t )], 'elementwise' ),
346
346
('min' , small_3d_unique , lambda t : []),
347
- ('min' , small_3d_unique , lambda t : [1 ], 'dim' , types , False , skipIfRocm ),
348
- ('min' , small_3d_unique , lambda t : [- 1 ], 'neg_dim' , types , False , skipIfRocm ),
347
+ ('min' , small_3d_unique , lambda t : [1 ], 'dim' ),
348
+ ('min' , small_3d_unique , lambda t : [- 1 ], 'neg_dim' ),
349
349
('min' , medium_2d , lambda t : [medium_2d (t )], 'elementwise' ),
350
350
('mean' , small_3d , lambda t : []),
351
351
('mean' , small_3d , lambda t : [- 1 ], 'neg_dim' ),
@@ -393,11 +393,11 @@ def tmp(t):
393
393
('size' , new_t (1 , 2 , 3 , 4 ), lambda t : [],),
394
394
('size' , new_t (1 , 2 , 3 , 4 ), lambda t : [1 ], 'dim' ),
395
395
('size' , new_t (1 , 2 , 3 , 4 ), lambda t : [- 2 ], 'neg_dim' ),
396
- ('sort' , small_3d_unique , lambda t : [], '' , types , False , skipIfRocm ),
397
- ('sort' , small_3d_unique , lambda t : [1 ], 'dim' , types , False , skipIfRocm ),
398
- ('sort' , small_3d_unique , lambda t : [- 1 ], 'neg_dim' , types , False , skipIfRocm ),
399
- ('sort' , small_3d_unique , lambda t : [1 , True ], 'dim_descending' , types , False , skipIfRocm ),
400
- ('sort' , small_3d_unique , lambda t : [- 1 , True ], 'neg_dim_descending' , types , False , skipIfRocm ),
396
+ ('sort' , small_3d_unique , lambda t : [], '' ),
397
+ ('sort' , small_3d_unique , lambda t : [1 ], 'dim' ),
398
+ ('sort' , small_3d_unique , lambda t : [- 1 ], 'neg_dim' ),
399
+ ('sort' , small_3d_unique , lambda t : [1 , True ], 'dim_descending' ),
400
+ ('sort' , small_3d_unique , lambda t : [- 1 , True ], 'neg_dim_descending' ),
401
401
('split' , small_3d , lambda t : [2 ],),
402
402
('split' , small_3d , lambda t : [2 , 1 ], 'dim' ),
403
403
('split' , small_3d , lambda t : [2 , - 3 ], 'neg_dim' ),
@@ -409,9 +409,9 @@ def tmp(t):
409
409
('transpose' , new_t (1 , 2 , 3 , 4 ), lambda t : [1 , 2 ],),
410
410
('transpose' , new_t (1 , 2 , 3 , 4 ), lambda t : [- 1 , - 2 ], 'neg_dim' ),
411
411
('to_list' , small_3d , lambda t : [],),
412
- ('topk' , small_3d_unique , lambda t : [2 , 1 , False , True ], 'dim_sort' , types , False , skipIfRocm ),
413
- ('topk' , small_3d_unique , lambda t : [2 , - 1 , False , True ], 'neg_dim_sort' , types , False , skipIfRocm ),
414
- ('topk' , small_3d_unique , lambda t : [2 , 1 , True , True ], 'dim_desc_sort' , types , False , skipIfRocm ),
412
+ ('topk' , small_3d_unique , lambda t : [2 , 1 , False , True ], 'dim_sort' , types , False , " skipIfRocm:HalfTensor" ),
413
+ ('topk' , small_3d_unique , lambda t : [2 , - 1 , False , True ], 'neg_dim_sort' , types , False , " skipIfRocm:HalfTensor" ),
414
+ ('topk' , small_3d_unique , lambda t : [2 , 1 , True , True ], 'dim_desc_sort' , types , False , " skipIfRocm:HalfTensor" ),
415
415
('trace' , medium_2d , lambda t : []),
416
416
('tril' , medium_2d , lambda t : [],),
417
417
('tril' , medium_2d_expanded , lambda t : [], 'zero_stride' , types , True ),
@@ -1210,11 +1210,9 @@ def test_cat(self):
1210
1210
z = torch .cat ([x , y ])
1211
1211
self .assertEqual (z .size (), (21 , SIZE , SIZE ))
1212
1212
1213
- @skipIfRocm
1214
1213
def test_cat_empty_legacy (self ):
1215
1214
TestTorch ._test_cat_empty_legacy (self , use_cuda = True )
1216
1215
1217
- @skipIfRocm
1218
1216
def test_cat_empty (self ):
1219
1217
TestTorch ._test_cat_empty (self , use_cuda = True )
1220
1218
@@ -1708,7 +1706,6 @@ def test_btrisolve(self):
1708
1706
def test_dim_reduction (self ):
1709
1707
TestTorch ._test_dim_reduction (self , lambda t : t .cuda ())
1710
1708
1711
- @skipIfRocm
1712
1709
def test_tensor_gather (self ):
1713
1710
TestTorch ._test_gather (self , lambda t : t .cuda (), False )
1714
1711
0 commit comments