From 754885c7c46d0239257402bf1afb70855334aca5 Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Tue, 11 Sep 2018 14:45:54 -0700 Subject: [PATCH 1/3] enable fp16 tests for test_nn --- test/common_nn.py | 1 + test/test_nn.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/test/common_nn.py b/test/common_nn.py index f159fe659672c..06ca4aadce72f 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -653,6 +653,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0 multilabelmarginloss_reference(i, t, reduction=get_reduction(m)), check_sum_reduction=True, check_gradgrad=False, + check_half=False ), dict( module_name='MultiLabelSoftMarginLoss', diff --git a/test/test_nn.py b/test/test_nn.py index b2597b894803f..daad6d2637e15 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -4918,6 +4918,7 @@ def test_pdist_empty_row(self): inp = torch.randn(1, 3, dtype=torch.double, device=device, requires_grad=True) self.assertTrue(gradcheck(F.pdist, (inp,))) + @skipIfRocm def test_pdist_empty_col(self): for device in device_(): inp = torch.randn(4, 0, dtype=torch.double, device=device, requires_grad=True) @@ -6572,7 +6573,6 @@ def add(test_name, fn): add(cuda_test_name + '_double', lambda self, test=test, kwargs=kwargs: test.test_cuda(self, dtype=torch.double, **kwargs)) - @skipIfRocm def test_half(self, test=test, kwargs=kwargs): test.test_cuda(self, dtype=torch.half, **kwargs) if getattr(test, 'check_half', True): From e62b4d7aa80e7e1ba96b7a613b641c374e27bb23 Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Tue, 11 Sep 2018 15:24:52 -0700 Subject: [PATCH 2/3] enabled multilabelmargin loss for fp16 --- test/common_nn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/common_nn.py b/test/common_nn.py index 06ca4aadce72f..f159fe659672c 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -653,7 +653,6 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0 multilabelmarginloss_reference(i, t, reduction=get_reduction(m)), check_sum_reduction=True, check_gradgrad=False, - check_half=False ), dict( module_name='MultiLabelSoftMarginLoss', From dbf77c8dd9fc1913aeb6e0161d204b2588b8b8b9 Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Wed, 12 Sep 2018 10:11:43 -0700 Subject: [PATCH 3/3] removed skip for test_pdist_empty_col --- test/test_nn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_nn.py b/test/test_nn.py index daad6d2637e15..4e4209edfc9d7 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -4918,7 +4918,6 @@ def test_pdist_empty_row(self): inp = torch.randn(1, 3, dtype=torch.double, device=device, requires_grad=True) self.assertTrue(gradcheck(F.pdist, (inp,))) - @skipIfRocm def test_pdist_empty_col(self): for device in device_(): inp = torch.randn(4, 0, dtype=torch.double, device=device, requires_grad=True)