diff --git a/test/torchaudio_unittest/prototype/functional/autograd_test_impl.py b/test/torchaudio_unittest/prototype/functional/autograd_test_impl.py index 92a69b7875..42f515218d 100644 --- a/test/torchaudio_unittest/prototype/functional/autograd_test_impl.py +++ b/test/torchaudio_unittest/prototype/functional/autograd_test_impl.py @@ -5,6 +5,7 @@ from parameterized import parameterized from torch.autograd import gradcheck from torchaudio_unittest.common_utils import TestBaseMixin +from torch.utils.cpp_extension import ROCM_HOME class AutogradTestImpl(TestBaseMixin): @@ -24,7 +25,10 @@ def test_oscillator_bank(self, sample_rate, shape): ) amps = torch.linspace(-5, 5, numel, dtype=self.dtype, device=self.device, requires_grad=True).reshape(shape) - assert gradcheck(F.oscillator_bank, (freq, amps, sample_rate)) + atol = 1e-05 + if ROCM_HOME is not None: + atol = 1e-04 + assert gradcheck(F.oscillator_bank, (freq, amps, sample_rate), atol=atol) def test_extend_pitch(self): num_frames, num_pitches = 5, 7