diff --git a/advanced_source/cpp_extension.rst b/advanced_source/cpp_extension.rst
index 9fe1db38d7a..265b3774f88 100644
--- a/advanced_source/cpp_extension.rst
+++ b/advanced_source/cpp_extension.rst
@@ -1009,6 +1009,39 @@ simpler::
 
     lltm = load(name='lltm', sources=['lltm_cuda.cpp', 'lltm_cuda_kernel.cu'])
 
+Note that when constructing a :func:`CUDAExtension`, you might have to pass
+``extra_compile_args`` to avoid collisions between the ``half`` operator
+overloads defined by PyTorch and those in the CUDA headers. (More on this
+issue can be found here:
+https://github.com/pytorch/pytorch/pull/10301#issuecomment-416773333)::
+
+  extra_compile_args = {'cxx': [],
+                        'nvcc': ['-DCUDA_HAS_FP16=1',
+                                 '-D__CUDA_NO_HALF_OPERATORS__',
+                                 '-D__CUDA_NO_HALF_CONVERSIONS__',
+                                 '-D__CUDA_NO_HALF2_OPERATORS__']}
+
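+For example, a ``setup.py`` along the lines of the one shown earlier could
+forward these flags to :func:`CUDAExtension` roughly as follows::
+
+  from setuptools import setup
+  from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+  setup(
+      name='lltm_cuda',
+      ext_modules=[
+          CUDAExtension('lltm_cuda',
+                        ['lltm_cuda.cpp', 'lltm_cuda_kernel.cu'],
+                        # Disable the conflicting half-precision operator
+                        # overloads in the CUDA headers when compiling with nvcc.
+                        extra_compile_args={'cxx': [],
+                                            'nvcc': ['-DCUDA_HAS_FP16=1',
+                                                     '-D__CUDA_NO_HALF_OPERATORS__',
+                                                     '-D__CUDA_NO_HALF_CONVERSIONS__',
+                                                     '-D__CUDA_NO_HALF2_OPERATORS__']})
+      ],
+      cmdclass={'build_ext': BuildExtension})
+
 Performance Comparison
 **********************
 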