diff --git a/torch_utils/ops/conv2d_gradfix.py b/torch_utils/ops/conv2d_gradfix.py index e95e10d0b..a7c5b6afd 100755 --- a/torch_utils/ops/conv2d_gradfix.py +++ b/torch_utils/ops/conv2d_gradfix.py @@ -12,6 +12,7 @@ import warnings import contextlib import torch +from distutils.version import LooseVersion # pylint: disable=redefined-builtin # pylint: disable=arguments-differ @@ -50,7 +51,7 @@ def _should_use_custom_op(input): return False if input.device.type != 'cuda': return False - if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): + if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'): return True warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') return False diff --git a/torch_utils/ops/grid_sample_gradfix.py b/torch_utils/ops/grid_sample_gradfix.py index ca6b3413e..a675a2150 100755 --- a/torch_utils/ops/grid_sample_gradfix.py +++ b/torch_utils/ops/grid_sample_gradfix.py @@ -13,6 +13,7 @@ import warnings import torch +from distutils.version import LooseVersion # pylint: disable=redefined-builtin # pylint: disable=arguments-differ @@ -34,7 +35,7 @@ def grid_sample(input, grid): def _should_use_custom_op(): if not enabled: return False - if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): + if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'): return True warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().') return False