From 0d11df1d35e902fbaab0dfb70735d59bf7a5d378 Mon Sep 17 00:00:00 2001 From: Johannes M Dieterich Date: Fri, 24 Aug 2018 01:26:25 -0500 Subject: [PATCH 1/2] Correctly type cast everything. --- aten/src/THC/generic/THCTensorSort.cu | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/aten/src/THC/generic/THCTensorSort.cu b/aten/src/THC/generic/THCTensorSort.cu index b6bcf6aecb4f78..764a508073f5df 100644 --- a/aten/src/THC/generic/THCTensorSort.cu +++ b/aten/src/THC/generic/THCTensorSort.cu @@ -56,21 +56,21 @@ THC_API void THCTensor_(sortKeyValueInplace)(THCState* state, bitonicSortKVInPlace, TYPE, SIZE> \ <<>>( \ keyInfo, \ - keySlices, \ - (TYPE) keySliceSize, \ - (TYPE) keyInfo.strides[collapseKeyDim], \ + static_cast(keySlices), \ + static_cast(keySliceSize), \ + static_cast(keyInfo.strides[collapseKeyDim]), \ valueInfo, \ - (TYPE) valueInfo.strides[collapseValueDim], \ + static_cast(valueInfo.strides[collapseValueDim]), \ GTComp()); \ } else { \ bitonicSortKVInPlace, TYPE, SIZE> \ <<>>( \ keyInfo, \ - keySlices, \ - (TYPE) keySliceSize, \ - (TYPE) keyInfo.strides[collapseKeyDim], \ + static_cast(keySlices), \ + static_cast(keySliceSize), \ + static_cast(keyInfo.strides[collapseKeyDim]), \ valueInfo, \ - (TYPE) valueInfo.strides[collapseValueDim], \ + static_cast(valueInfo.strides[collapseValueDim]), \ LTComp()); \ } \ } while (0) From 3ddabadb37a0ee23ebe53fa845a7fc89e268126b Mon Sep 17 00:00:00 2001 From: Johannes M Dieterich Date: Sun, 2 Sep 2018 18:34:28 -0500 Subject: [PATCH 2/2] CUDA does not really allow passing objects by reference to __global__ functions. Submitted by: AlexVIx --- aten/src/THC/THCSortUtils.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/THC/THCSortUtils.cuh b/aten/src/THC/THCSortUtils.cuh index 518063a229c723..891e77067686e7 100644 --- a/aten/src/THC/THCSortUtils.cuh +++ b/aten/src/THC/THCSortUtils.cuh @@ -142,7 +142,7 @@ bitonicSortKVInPlace(TensorInfo keys, IndexType keySliceStride, TensorInfo values, IndexType valueSliceStride, - const Comparator& comp) { + const Comparator comp) { // Find the slice of the tensor that we are sorting const IndexType linearIndex = getLinearBlockId(); // Tiling the slices could have us be out of bounds, if there are a