diff --git a/test/test_nn.py b/test/test_nn.py index b2597b894803f6..4e4209edfc9d70 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -6572,7 +6572,6 @@ def add(test_name, fn): add(cuda_test_name + '_double', lambda self, test=test, kwargs=kwargs: test.test_cuda(self, dtype=torch.double, **kwargs)) - @skipIfRocm def test_half(self, test=test, kwargs=kwargs): test.test_cuda(self, dtype=torch.half, **kwargs) if getattr(test, 'check_half', True):