From aef047a064fb7b33f37a64f2b780832e47fceb01 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Wed, 25 Sep 2024 11:30:13 +0200 Subject: [PATCH 1/5] SM75 support for FP6 kernel --- benchmarks/benchmark_fp6.py | 2 +- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 2 +- torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh | 6 ++++ torchao/csrc/cuda/fp6_llm/ptx_mma.cuh | 31 +++++++++++++++++++++ torchao/csrc/cuda/fp6_llm/utils_gmem.cuh | 17 ++++++++++- 5 files changed, 55 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py index 9b8dcf3387..f84d886cb4 100644 --- a/benchmarks/benchmark_fp6.py +++ b/benchmarks/benchmark_fp6.py @@ -7,7 +7,7 @@ from tqdm import tqdm -def benchmark(m: int, k: int, n: int): +def benchmark(m: int, n: int, k: int): float_data = torch.randn(n, k, dtype=torch.half, device="cuda") fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayoutType(3, 2)) fp16_weight = fp6_weight.dequantize(torch.half) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 7f973e6987..2b5d4e6aa9 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -14,7 +14,7 @@ // // This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/fp6_linear.cu -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750 // at least Turing #include "kernel_matmul.cuh" #include "kernel_reduction.cuh" diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh index f2c137828d..d19e0d5bf3 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -140,7 +140,9 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, for(int j=0; j= 800 cp_async_wait_all(); + #endif __syncthreads(); ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -175,12 +177,16 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, if(USE_SEG_4BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_4bit, WARP_StartGPTR_A_4BIT, GlobalCopy); // copying B tile from GlobalMemory to SharedMemory CopyFromGlobalToShared (write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy); + #if __CUDA_ARCH__ >= 800 cp_async_group_commit(); + #endif core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag_2bit, read_SPTR_Frag_4bit are different for each WARP; read_SPTR is shared among WARPs core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 2); core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 3); // Barriers and Synchronizations + #if __CUDA_ARCH__ >= 800 cp_async_wait_group(); + #endif __syncthreads(); core_mma_slice(c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit, read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0); // Updating global PTRs diff --git a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh index 1658352ee5..bfabddcae3 100644 --- a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh +++ b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh @@ -55,6 +55,14 @@ __device__ __forceinline__ void 
B_FromSharedToReg(uint32_t (* __restrict__ Reg)[ assert( warp_start_col==0 ); #endif + #if __CUDA_ARCH__ == 750 + if (TilingConfig::WARP_COL_MMA_TENSORS==1) { + // For .target sm_75, all threads must contain valid addresses for the 'ldmatrix' op. below. Otherwise, the behavior is undefined. + // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-load-instruction-ldmatrix + // To avoid this, we make threads 16-32 point to the same smem addresses as threads 0-15 by changing the lane id. + lane_id = lane_id % 16; + } + #endif int col = (lane_id%8) + (lane_id/16)*8; int row = (lane_id%16) / 8 * 8; uint32_t smem_local_ptr = static_cast(__cvta_generic_to_shared(&read_SPTR[warp_start_col+col][slice_id*MMA_16 + row])); @@ -80,6 +88,28 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[ __device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b) { + #if __CUDA_ARCH__ == 750 + // m16n8k16 op. requires >=sm_80, so instead we use two m16n8k8 ops. + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{ %0, %1, %2, %3}," + "{ %4, %5}," + "{ %6 }," + "{ %7, %8, %9, %10 };" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), + "r"(b[0]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{ %0, %1, %2, %3}," + "{ %4, %5}," + "{ %6 }," + "{ %7, %8, %9, %10 };" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[2]), "r"(a[3]), + "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + + #else asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" "{ %0, %1, %2, %3}," "{ %4, %5, %6, %7 }," @@ -89,6 +119,7 @@ MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + #endif } #endif diff --git a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh index a74930ba44..5892a7be33 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh @@ -39,7 +39,15 @@ __device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, GPTR_HALF += lane_id*8; #pragma unroll for(int i=0; i(SPTR_HALF); + const float4* GPTR_VEC = reinterpret_cast(GPTR_HALF); + SPTR_VEC[0] = GPTR_VEC[0]; + } + #else cp_async<16>( SPTR_HALF, GPTR_HALF, pred_guard); + #endif SPTR_HALF += 256; // Forward 512 Bytes GPTR_HALF += 256; // Forward 512 Bytes } @@ -82,8 +90,15 @@ __device__ __forceinline__ void CopyFromGlobalToShared(half (* __restrict__ Shar #pragma unroll for (int i = 0; i < MaxIteration; i++) { bool AsyncCopyPred = (line_id+i*NumOfGroups) < NumOfLinesLeft && Pred; + #if __CUDA_ARCH__ == 750 + if (AsyncCopyPred) { + float4* SharedPtrVec = reinterpret_cast(&(*SharedPTR)[line_offset]); + const float4* GlobalPtrVec = reinterpret_cast(GlobalPTR); + SharedPtrVec[0] = GlobalPtrVec[0]; + } + #else cp_async<16>( &(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred); - // + #endif GlobalPTR += NumOfGroups * GlobalStride; SharedPTR += NumOfGroups; } From d39017914cddc2178bef2451527c36b2c513f529 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Thu, 26 Sep 2024 08:03:18 +0200 Subject: [PATCH 2/5] More consistent argument ordering in benchmark function --- benchmarks/benchmark_fp6.py | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-)

diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py
index f84d886cb4..425507bd95 100644
--- a/benchmarks/benchmark_fp6.py
+++ b/benchmarks/benchmark_fp6.py
@@ -7,7 +7,7 @@
 from tqdm import tqdm


-def benchmark(m: int, n: int, k: int):
+def benchmark(m: int, k: int, n: int):
     float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
     fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayoutType(3, 2))
     fp16_weight = fp6_weight.dequantize(torch.half)
@@ -43,7 +43,7 @@ def benchmark(m: int, n: int, k: int):

     for m in tqdm([1 << i for i in range(10)]):
         for n, k in zip(n_vals, k_vals):
-            results.append(benchmark(m, n, k))
+            results.append(benchmark(m, k, n))

     df = pd.DataFrame(results)
     df.to_csv("fp6_llm_benchmark_results.csv", index=False)

From 0a4d70e72004e71eaef35cebbe31ce90e11acb01 Mon Sep 17 00:00:00 2001
From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com>
Date: Thu, 26 Sep 2024 09:08:36 +0200
Subject: [PATCH 3/5] Add a note about SM75 support in the floatx README

---
 torchao/dtypes/floatx/README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torchao/dtypes/floatx/README.md b/torchao/dtypes/floatx/README.md
index f4cbf51a03..af770cf65c 100644
--- a/torchao/dtypes/floatx/README.md
+++ b/torchao/dtypes/floatx/README.md
@@ -43,6 +43,7 @@ outputs = quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scales) # shape
 - Since this kernel's computation dtype is FP16, it is recommended to convert the model to FP16 (instead of BF16) before applying quantization and use FP16 for activations.
 - Only FP6 E3M2 and FP5 E2M2 are tested and enabled in the official repo. We additionally enable support for FP6 E2M3 and FP5 E3M1.
 - On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion. See https://github.com/pytorch/ao/pull/223 for some microbenchmark results.
+- FP6 is supported for >=SM80 (Ampere generation) as well as SM75 (Turing generation) GPUs. However, SM75 support requires manual compilation of the C++/CUDA extensions (see the installation instructions in the [README](https://github.com/pytorch/ao/blob/main/README.md#installation) for details).
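[Editor's note — illustrative example, not part of the patch series] The README note above states the hardware requirement in terms of compute capability (SM75/Turing or newer). A minimal host-side sketch of how a caller could verify this at runtime is shown below; the helper name fp6_kernel_supported is hypothetical and does not appear in this patch series.

    #include <cuda_runtime.h>
    #include <cstdio>

    // Returns true if the current device is Turing (SM 7.5) or newer, i.e. a GPU the FP6 kernel
    // can target (SM75 via the fallback paths added in this series, SM80+ natively).
    // Illustrative sketch only; not part of the patch.
    bool fp6_kernel_supported() {
        int device = 0;
        if (cudaGetDevice(&device) != cudaSuccess) return false;
        cudaDeviceProp props;
        if (cudaGetDeviceProperties(&props, device) != cudaSuccess) return false;
        return props.major > 7 || (props.major == 7 && props.minor >= 5);
    }

    int main() {
        printf("FP6 kernel supported on this GPU: %s\n", fp6_kernel_supported() ? "yes" : "no");
        return 0;
    }

The runtime dispatch added later in this series (isSM75GPU in fp6_linear.cu) performs a similar device query, but checks for exactly SM75 in order to select the SM75-specific tiling configuration.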
## End-to-End benchmarks From 650ba03d7df2bbb5c78e0620f71911ecc6c60927 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Thu, 26 Sep 2024 12:14:41 +0200 Subject: [PATCH 4/5] Handle FP6 + SM75 + N>=64 edge case --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 79 +++++++++++++++++-------- 1 file changed, 54 insertions(+), 25 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 2b5d4e6aa9..be80e77a78 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -22,6 +22,22 @@ #include #include +inline bool isSM75GPU() { + int device; + cudaError_t err = cudaGetDevice(&device); + if (err != cudaSuccess) { + return false; + } + + cudaDeviceProp props; + err = cudaGetDeviceProperties(&props, device); + if (err != cudaSuccess) { + return false; + } + + return (props.major == 7) && (props.minor == 5); +} + template static void Kernel_Ex(cudaStream_t stream, const uint4 *Weight, @@ -80,38 +96,51 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, if(N_Global>64 && N_Global<=128) N_PowerOf2 = 128; if(N_Global>128) N_PowerOf2 = ((N_Global-1)/128+1) * 128; - if (Split_K == 1) { - switch (N_PowerOf2) { - case 8: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 16: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 32: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 64: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 128: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - default: if (N_PowerOf2 % 128 != 0) { - printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); - return cudaErrorUnknown; - } - Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + if (isSM75GPU() && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) { + // For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory. 
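        // [Editor's illustration, not part of the patch] Worked examples of the N_PowerOf2
        // rounding above, which decides whether this SM75-specific branch is taken:
        //   64 < N_Global <= 128   ->  N_PowerOf2 = 128                        -> SM75 branch taken
        //   N_Global = 200 (> 128) ->  N_PowerOf2 = ((200-1)/128+1)*128 = 256  -> taken (multiple of 128)
        //   N_Global that rounds to 8, 16 or 32 -> condition is false, so execution falls through
        //                                          to the regular TilingConfig dispatch below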
+ if (Split_K == 1) { + Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + } else { + Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); } - } - else { - switch (N_PowerOf2) { - case 8: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 16: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 32: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 64: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 128: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - default: if (N_PowerOf2 % 128 != 0) { - printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); - return cudaErrorUnknown; - } - Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + } else { + if (Split_K == 1) { + switch (N_PowerOf2) { + case 8: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 16: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 32: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 64: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 128: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + default: if (N_PowerOf2 % 128 != 0) { + printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); + return cudaErrorUnknown; + } + Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + } + } + else { + switch (N_PowerOf2) { + case 8: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 16: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 32: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 64: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 128: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + default: if (N_PowerOf2 % 128 != 0) { + printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); + return cudaErrorUnknown; + } + Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + } } + } + + if (Split_K != 1) { // Reduction for SplitK dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1); dim3 BlockDim(WARP_SIZE, 1, 1); SplitK_Reduction<<>>(C, Reduction_Workspace, M_Global, N_Global, Split_K); } + return 
cudaGetLastError();
 }

From 56718f9c6d38e0cd38c7929a1ca6d86b4b6d5e58 Mon Sep 17 00:00:00 2001
From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com>
Date: Thu, 26 Sep 2024 12:28:58 +0200
Subject: [PATCH 5/5] Document changes made for FP6 SM75 support

---
 torchao/csrc/cuda/fp6_llm/fp6_linear.cu     | 4 ++++
 torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh | 4 ++++
 torchao/csrc/cuda/fp6_llm/ptx_mma.cuh       | 4 ++++
 torchao/csrc/cuda/fp6_llm/utils_gmem.cuh    | 4 ++++
 4 files changed, 16 insertions(+)

diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu
index be80e77a78..b4cbe99160 100644
--- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu
+++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu
@@ -13,6 +13,10 @@
 // limitations under the License.
 //
 // This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/fp6_linear.cu
+//
+// MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942):
+// - Modified the TilingConfig parameters for SM75 to deal with smaller shared memory
+//

 #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750 // at least Turing

diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh
index d19e0d5bf3..600debd4a0 100644
--- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh
+++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh
@@ -13,6 +13,10 @@
 // limitations under the License.
 //
 // This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/kernel_matmul.cuh
+//
+// MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942):
+// - Added __CUDA_ARCH__ guards such that async operations are only executed for SM80 and up
+//

 #include "configs.h"
 #include "utils_gmem.cuh"
diff --git a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh
index bfabddcae3..dededcf19d 100644
--- a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh
+++ b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh
@@ -13,6 +13,10 @@
 // limitations under the License.
 //
 // This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/ptx_mma.cuh
+//
+// MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942):
+// - Replaced m16n8k16 Tensor core operation with two m16n8k8 operations
+// - Accounted for a difference in expected parameters for the ldmatrix operation

 /***************************************************************************
  * Copyright 2023 The FLash-LLM Authors. All rights reserved.
diff --git a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh
index 5892a7be33..f2af30733f 100644
--- a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh
+++ b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh
@@ -13,6 +13,10 @@
 // limitations under the License.
 //
 // This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh
+//
+// MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942):
+// - Replaced asynchronous copy operations with vectorized loads
+//

 #ifndef UTILS_GMEM_CUH
 #define UTILS_GMEM_CUH
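[Editor's note — illustrative example, not part of the patch series] The modification note above summarizes the pattern used throughout the SM75 code paths: the 16-byte cp.async copies are replaced by ordinary 128-bit vectorized loads and stores. A minimal, self-contained sketch of that synchronous copy pattern follows; the kernel name stage_through_smem and the launch shape are assumptions for this sketch, not taken from the patch.

    #include <cuda_fp16.h>
    #include <cstdio>

    // Each thread stages 8 halves (16 bytes) into shared memory with a single float4
    // load/store -- the synchronous replacement for cp.async used by the SM75 path --
    // and then writes them back to global memory. Illustrative sketch only.
    __global__ void stage_through_smem(const half* __restrict__ in, half* __restrict__ out, int n_half) {
        __shared__ half smem[256 * 8];                   // 8 halves per thread, 256 threads per block
        int idx   = (blockIdx.x * blockDim.x + threadIdx.x) * 8;
        bool pred = (idx + 8) <= n_half;                 // tail guard, analogous to AsyncCopyPred
        if (pred) {
            // 16-byte vectorized copy, global -> shared; works on compute capability 7.5
            // because it does not rely on the cp.async instruction.
            *reinterpret_cast<float4*>(&smem[threadIdx.x * 8]) =
                *reinterpret_cast<const float4*>(&in[idx]);
        }
        __syncthreads();
        if (pred) {
            *reinterpret_cast<float4*>(&out[idx]) =
                *reinterpret_cast<const float4*>(&smem[threadIdx.x * 8]);
        }
    }

    int main() {
        const int n = 4096;                              // number of half elements
        half *d_in, *d_out;
        cudaMalloc(&d_in,  n * sizeof(half));
        cudaMalloc(&d_out, n * sizeof(half));
        // Data is left uninitialized here; a real test would fill d_in and verify d_out.
        stage_through_smem<<<(n / 8 + 255) / 256, 256>>>(d_in, d_out, n);
        cudaDeviceSynchronize();
        printf("kernel status: %s\n", cudaGetErrorString(cudaGetLastError()));
        cudaFree(d_in);
        cudaFree(d_out);
        return 0;
    }

The trade-off is that these loads are synchronous: unlike cp.async, they cannot be overlapped with Tensor Core work via cp_async_group_commit / cp_async_wait_group, which is why those calls are compiled out for architectures below SM80 in kernel_matmul.cuh.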