
Commit d41a2f5

goldsborough authored and PenghuiCheng committed
Add flags to fix half comparison and test (pytorch#11395)
Summary: It was reported that there are some issues when using comparison operators for half types when certain THC headers are included. I was able to reproduce the problem and added a test. I also fixed the issue by adding the proper preprocessor definitions to avoid it. Reported in pytorch#10301 (comment). Related: pytorch/tutorials#292

cc soumith fmassa

Pull Request resolved: pytorch#11395
Differential Revision: D9725102
Pulled By: goldsborough
fbshipit-source-id: 630425829046bbebea3409bb792a9d62c91f41ad
1 parent f497e32 commit d41a2f5
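
Background for the fix: the new nvcc defines (-D__CUDA_NO_HALF_OPERATORS__, -D__CUDA_NO_HALF_CONVERSIONS__, -D__CUDA_NO_HALF2_OPERATORS__) compile out CUDA's built-in __half operator overloads so they no longer collide with the half comparison helpers that the THC headers bring in. As a hedged sketch, on a PyTorch build that predates this commit the same defines could be passed by hand through extra_cuda_cflags (a real parameter of torch.utils.cpp_extension.load; the workaround itself is an assumption, not part of this diff, and the file names are hypothetical):

    # Hypothetical manual workaround: replicate the new COMMON_NVCC_FLAGS
    # defines when building a CUDA extension by hand.
    from torch.utils.cpp_extension import load

    module = load(
        name='my_half_extension',  # hypothetical extension name
        sources=['my_extension.cpp', 'my_kernels.cu'],  # hypothetical files
        extra_cuda_cflags=[
            '-D__CUDA_NO_HALF_OPERATORS__',    # drop built-in __half operators
            '-D__CUDA_NO_HALF_CONVERSIONS__',  # drop implicit __half conversions
            '-D__CUDA_NO_HALF2_OPERATORS__',   # drop built-in __half2 operators
        ],
        verbose=True)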

File tree

5 files changed: +69 -2 lines changed

.gitignore
test/cpp_extensions/half_support.cpp
test/cpp_extensions/half_support.cu
test/test_cpp_extensions.py
torch/utils/cpp_extension.py

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ test/data/legacy_modules.t7
 test/data/legacy_serialized.pt
 test/data/linear.pt
 test/htmlcov
+test/cpp_extensions/install/
 third_party/build/
 tools/shared/_utils_internal.py
 torch.egg-info/

test/cpp_extensions/half_support.cpp

Whitespace-only changes.

test/cpp_extensions/half_support.cu

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+#include <torch/torch.h>
+
+#include <THC/THCNumerics.cuh>
+
+template <typename T, typename U>
+__global__ void half_test_kernel(const T* input, U* output) {
+  if (input[0] < input[1] || input[0] >= input[1]) {
+    output[0] = 123;
+  }
+}
+
+at::Tensor half_test(at::Tensor input) {
+  auto output = at::empty(1, input.options().dtype(at::kFloat));
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "half_test", [&] {
+    half_test_kernel<scalar_t>
+        <<<1, 1>>>(input.data<scalar_t>(), output.data<float>());
+  });
+  return output;
+}
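
For non-NaN inputs the kernel's condition (a < b || a >= b) is always true, so output[0] is set to 123 whenever the half comparisons compile and run at all; triggering that compilation is the entire point of the file. A hedged sketch of driving it from Python, assuming half_support.cpp supplies the pybind11 binding for half_test (its contents are not shown in this diff):

    # Sketch: build the half_support pair and run the kernel once.
    import torch
    from torch.utils.cpp_extension import load

    ext = load(
        name='half_support',
        sources=['test/cpp_extensions/half_support.cpp',
                 'test/cpp_extensions/half_support.cu'])

    x = torch.randn(3, device='cuda', dtype=torch.half)
    print(ext.half_test(x))  # expected: tensor([123.], device='cuda:0')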

test/test_cpp_extensions.py

Lines changed: 41 additions & 0 deletions
@@ -274,6 +274,47 @@ def test_complex_registration(self):

         torch.empty(2, 2, dtype=torch.complex64)

+    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
+    def test_half_support(self):
+        '''
+        Checks for an issue with operator< ambiguity for half when certain
+        THC headers are included.
+
+        See https://github.com/pytorch/pytorch/pull/10301#issuecomment-416773333
+        for the corresponding issue.
+        '''
+        cuda_source = '''
+        #include <THC/THCNumerics.cuh>
+
+        template<typename T, typename U>
+        __global__ void half_test_kernel(const T* input, U* output) {
+          if (input[0] < input[1] || input[0] >= input[1]) {
+            output[0] = 123;
+          }
+        }
+
+        at::Tensor half_test(at::Tensor input) {
+          auto output = at::empty(1, input.options().dtype(at::kFloat));
+          AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "half_test", [&] {
+            half_test_kernel<scalar_t><<<1, 1>>>(
+              input.data<scalar_t>(),
+              output.data<float>());
+          });
+          return output;
+        }
+        '''
+
+        module = torch.utils.cpp_extension.load_inline(
+            name='half_test_extension',
+            cpp_sources='at::Tensor half_test(at::Tensor input);',
+            cuda_sources=cuda_source,
+            functions=['half_test'],
+            verbose=True)
+
+        x = torch.randn(3, device='cuda', dtype=torch.half)
+        result = module.half_test(x)
+        self.assertEqual(result[0], 123)
+

 if __name__ == '__main__':
     common.run_tests()
torch/utils/cpp_extension.py

Lines changed: 8 additions & 2 deletions
@@ -69,6 +69,12 @@ def _find_cuda_home():
 # it the below pattern.
 BUILT_FROM_SOURCE_VERSION_PATTERN = re.compile(r'\d+\.\d+\.\d+\w+\+\w+')

+COMMON_NVCC_FLAGS = [
+    '-D__CUDA_NO_HALF_OPERATORS__',
+    '-D__CUDA_NO_HALF_CONVERSIONS__',
+    '-D__CUDA_NO_HALF2_OPERATORS__',
+]
+

 def is_binary_build():
     return not BUILT_FROM_SOURCE_VERSION_PATTERN.match(torch.version.__version__)

@@ -165,7 +171,7 @@ def unix_wrap_compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
             self.compiler.set_executable('compiler_so', nvcc)
             if isinstance(cflags, dict):
                 cflags = cflags['nvcc']
-            cflags += ['--compiler-options', "'-fPIC'"]
+            cflags = COMMON_NVCC_FLAGS + ['--compiler-options', "'-fPIC'"] + cflags
         elif isinstance(cflags, dict):
             cflags = cflags['cxx']
         # NVCC does not allow multiple -std to be passed, so we avoid

@@ -831,7 +837,7 @@ def _write_ninja_file(path,
     flags = ['cflags = {}'.format(' '.join(cflags))]

     if with_cuda:
-        cuda_flags = common_cflags
+        cuda_flags = common_cflags + COMMON_NVCC_FLAGS
         if sys.platform == 'win32':
             cuda_flags = _nt_quote_args(cuda_flags)
         else:
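
Note the ordering in unix_wrap_compile: COMMON_NVCC_FLAGS is prepended, so the defines are always present while any caller-supplied flags still come later on the command line. A small, runnable illustration of how the flag list is assembled (the variable names mirror the diff; '-O2' stands in for hypothetical user flags):

    COMMON_NVCC_FLAGS = [
        '-D__CUDA_NO_HALF_OPERATORS__',
        '-D__CUDA_NO_HALF_CONVERSIONS__',
        '-D__CUDA_NO_HALF2_OPERATORS__',
    ]

    user_cflags = ['-O2']  # hypothetical flags passed by the caller
    # Mirrors the new line in unix_wrap_compile:
    cflags = COMMON_NVCC_FLAGS + ['--compiler-options', "'-fPIC'"] + user_cflags
    print(' '.join(cflags))
    # -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__
    # -D__CUDA_NO_HALF2_OPERATORS__ --compiler-options '-fPIC' -O2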
