Skip to content

Commit c5e00e3

Browse files
authored
Fix types and warp sizes for ROCm (ROCm 256)
* Correct the warp size for current AMD GPUs. * Fix copy paste error located by Jithun Nair. * Correct the wrong typing explicitly.
1 parent 1a0d82e commit c5e00e3

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

aten/src/THC/generic/THCTensorIndex.cu

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -560,16 +560,17 @@ void THCTensor_(indexSelect)(THCState *state, THCTensor *dst, THCTensor *src, in
560560
indexSelectSmallIndex<TENSOR_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM> \
561561
<<<smallIndexGrid, smallIndexBlock, 0, stream>>>( \
562562
dstInfo, srcInfo, indicesInfo, \
563-
dstSelectDim, srcSelectDim, sliceSize, srcSelectDimSize);
563+
dstSelectDim, srcSelectDim, static_cast<TYPE>(sliceSize), \
564+
srcSelectDimSize);
564565

565566
#define LARGE_INDEX(TENSOR_TYPE, TYPE, \
566567
DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR) \
567568
indexSelectLargeIndex<TENSOR_TYPE, TYPE, \
568569
DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR> \
569570
<<<largeIndexGrid, largeIndexBlock, 0, stream>>>( \
570571
dstInfo, srcInfo, indicesInfo, \
571-
dstSelectDim, srcSelectDim, dstTotalSize, \
572-
(IDX_IS_MAJOR) ? sliceSize : numIndices, \
572+
dstSelectDim, srcSelectDim, static_cast<TYPE>(dstTotalSize), \
573+
static_cast<TYPE>((IDX_IS_MAJOR) ? sliceSize : numIndices), \
573574
srcSelectDimSize);
574575

575576
dim3 smallIndexGrid(std::min(THCCeilDiv(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8)));

aten/src/THCUNN/LookupTableBag.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
#include "THCHalfAutoNumerics.cuh"
1616
#include "THCTensorSort.cuh"
1717

18+
#if defined(__HIP_PLATFORM_HCC__)
19+
const int WARP_SIZE = 64;
20+
#else
1821
const int WARP_SIZE = 32;
22+
#endif
1923
const int MODE_SUM = 0;
2024
const int MODE_MEAN = 1;
2125

cmake/public/LoadHIP.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ ENDIF()
4747

4848
# ROCFFT_PATH
4949
IF(NOT DEFINED ENV{ROCFFT_PATH})
50-
SET(ROCBLAS_PATH ${ROCM_PATH}/rocfft)
50+
SET(ROCFFT_PATH ${ROCM_PATH}/rocfft)
5151
ELSE()
5252
SET(ROCFFT_PATH $ENV{ROCFFT_PATH})
5353
ENDIF()

0 commit comments

Comments
 (0)