From f50d8d745e4dfa8c5d0c8ceed2e7cfa7f27d87c7 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Thu, 10 Oct 2024 11:04:21 +0200 Subject: [PATCH 01/30] Add FP6 benchmark option to use BF16 --- benchmarks/benchmark_fp6.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py index c6d28c0bd1..c5ae9bfef6 100644 --- a/benchmarks/benchmark_fp6.py +++ b/benchmarks/benchmark_fp6.py @@ -7,12 +7,13 @@ from tqdm import tqdm -def benchmark(m: int, k: int, n: int): - float_data = torch.randn(n, k, dtype=torch.half, device="cuda") +def benchmark(m: int, k: int, n: int, use_bf16=False): + dtype = torch.bfloat16 if use_bf16 else torch.half + float_data = torch.randn(n, k, dtype=dtype, device="cuda") fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayout(3, 2)) - fp16_weight = fp6_weight.dequantize(torch.half) + fp16_weight = fp6_weight.dequantize(dtype) - fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda") + fp16_act = torch.randn(m, k, dtype=dtype, device="cuda") fp6_output = F.linear(fp16_act, fp6_weight) fp16_output = F.linear(fp16_act, fp16_weight) @@ -28,7 +29,7 @@ def benchmark(m: int, k: int, n: int): "k": k, "n": n, "fp6_latency (ms)": fp6_time, - "fp16_latency (ms)": fp16_time, + f"{'bf16' if use_bf16 else 'fp16'}_latency (ms)": fp16_time, "speedup (d/s)": fp16_time / fp6_time, "correct": correct, } @@ -39,11 +40,13 @@ def benchmark(m: int, k: int, n: int): k_vals = (8192, 8192, 8192, 28672) n_vals = (8192, 10240, 57344, 8192) + use_bf16 = True + results = [] for m in tqdm([1 << i for i in range(10)]): for n, k in zip(n_vals, k_vals): - results.append(benchmark(m, k, n)) + results.append(benchmark(m, k, n, use_bf16=use_bf16)) df = pd.DataFrame(results) df.to_csv("fp6_llm_benchmark_results.csv", index=False) From a71437794ee71214d15aeb6d5739ca1aefbb22ff Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Fri, 11 Oct 2024 11:41:49 +0200 Subject: [PATCH 02/30] Change dequant bit-shifting logic for BF16 --- .../csrc/cuda/fp6_llm/utils_parallel_dequant.cuh | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index 4c8c39603e..405aa84f36 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -30,9 +30,10 @@ template __device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t *In, uint32_t *Out1, uint32_t *Out2) { // - constexpr int RIGHT_SHIFT = 5 - EXPONENT; + constexpr bool USE_BF16 = true; // TODO: don't hardcode here + constexpr int RIGHT_SHIFT = USE_BF16 ? 8 - EXPONENT : 5 - EXPONENT; constexpr int MASK1 = 0x80000000; - constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA; + constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA; // NB: arithmetic shift, not logical constexpr int MASK3 = MASK2 & 0x7fffffff; constexpr int MASK = MASK3 | MASK3 >> 16; // @@ -46,13 +47,16 @@ __device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t *In, uint32_t *Out1, template __device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scale) { - constexpr int BIAS_OFFSET = (int(1) << (5-1)) - (int(1) << (EXPONENT-1)); + constexpr bool USE_BF16 = true; // TODO: don't hardcode here + constexpr EXP_16 = USE_BF16 ? 
8 : 5; + constexpr int BIAS_OFFSET = (int(1) << (EXP_16-1)) - (int(1) << (EXPONENT-1)); constexpr int BIAS = int(1) << BIAS_OFFSET; // half* FP16_1 = reinterpret_cast(&PackedFP16Pair); half* FP16_2 = FP16_1 + 1; uint32_t output; half* output_half_ptr = reinterpret_cast(&output); + // TODO: should these ops be bf16-specific? output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(1.0f*BIAS)), Scale); output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2half(1.0f*BIAS)), Scale); return output; @@ -104,9 +108,12 @@ __device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg) if(i%2==1) Frag_PTR_4bit++; else (*Frag_PTR_4bit) = (*Frag_PTR_4bit) << 4; } + // Packed_FP6 now contains 4x 1234 5600 // uint32_t out1, out2; FPx_FP16_Cast_4Way(&Packed_FP6, &out1, &out2); + // out1 now contains 2 FP16 values, as shown by R1 in figure 6 + // out2 now contains 2 FP16 values, as shown by R2 in figure 6 // *OutputRegs = MultScale(out1, Scale_RPTR[0] ); // Muliply FP16 scales OutputRegs += 1; From 5af3b7eef67003a14f0376e434fc6695e76eb4dc Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Mon, 14 Oct 2024 08:57:24 +0200 Subject: [PATCH 03/30] Modify dequant + tensor core ops for bf16 --- torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh | 8 +++-- torchao/csrc/cuda/fp6_llm/ptx_mma.cuh | 31 +++++++++++++------ .../cuda/fp6_llm/utils_parallel_dequant.cuh | 25 ++++++++++----- 3 files changed, 46 insertions(+), 18 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh index 600debd4a0..560ca107cc 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -212,7 +212,11 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, #pragma unroll for(size_t j=threadIdx.x%WARP_SIZE; j::value) BlockGlobalPTR[j+i*M_Global] = __float2half_rn(smem_CFrag[i][j]); - else BlockGlobalPTR[j+i*M_Global] = smem_CFrag[i][j]; + if constexpr (std::is_same::value) + BlockGlobalPTR[j+i*M_Global] = __float2half_rn(smem_CFrag[i][j]); + else if constexpr (std::is_same::value) + BlockGlobalPTR[j+i*M_Global] = __float2bfloat16_rn(smem_CFrag[i][j]); + else + BlockGlobalPTR[j+i*M_Global] = smem_CFrag[i][j]; } } diff --git a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh index dededcf19d..9708222a88 100644 --- a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh +++ b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh @@ -92,6 +92,7 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[ __device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b) { + constexpr bool USE_BF16 = true; // TODO: don't hardcode here #if __CUDA_ARCH__ == 750 // m16n8k16 op. requires >=sm_80, so instead we use two m16n8k8 ops. 
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" @@ -114,15 +115,27 @@ MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); #else - asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{ %0, %1, %2, %3}," - "{ %4, %5, %6, %7 }," - "{ %8, %9 }," - "{ %10, %11, %12, %13 };" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), - "r"(b[0]), "r"(b[1]), - "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + if (USE_BF16) { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32" + "{ %0, %1, %2, %3}," + "{ %4, %5, %6, %7 }," + "{ %8, %9 }," + "{ %10, %11, %12, %13 };" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + } else { // FP16 + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{ %0, %1, %2, %3}," + "{ %4, %5, %6, %7 }," + "{ %8, %9 }," + "{ %10, %11, %12, %13 };" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + } #endif } diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index 405aa84f36..c97c549bc1 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -20,6 +20,7 @@ #include #include +#include #include /* @@ -48,17 +49,27 @@ __device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t *In, uint32_t *Out1, template __device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scale) { constexpr bool USE_BF16 = true; // TODO: don't hardcode here - constexpr EXP_16 = USE_BF16 ? 8 : 5; + constexpr int EXP_16 = USE_BF16 ? 8 : 5; constexpr int BIAS_OFFSET = (int(1) << (EXP_16-1)) - (int(1) << (EXPONENT-1)); constexpr int BIAS = int(1) << BIAS_OFFSET; // - half* FP16_1 = reinterpret_cast(&PackedFP16Pair); - half* FP16_2 = FP16_1 + 1; uint32_t output; - half* output_half_ptr = reinterpret_cast(&output); - // TODO: should these ops be bf16-specific? 
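To make the constants above concrete: for BF16 the cast in FPx_FP16_Cast_4Way keeps the FP6 sign in bit 15 and shifts the 5-bit exponent+mantissa field right by RIGHT_SHIFT = 8 - EXPONENT = 5, so the FP6 exponent bits land in the low bits of the BF16 exponent field and the FP6 mantissa bits in the top of the BF16 mantissa. The interpreted BF16 value is then too small by exactly 2^BIAS_OFFSET = 2^124, which MultScale restores. A single-value Python sketch of that arithmetic (illustration only, not part of the patch; assumes a normal FP6 E3M2 input with layout s|eee|mm):

import struct

def dequant_fp6_e3m2_to_bf16(fp6_bits: int, scale: float = 1.0) -> float:
    sign    = (fp6_bits >> 5) & 0x1
    exp_man =  fp6_bits       & 0x1F                 # 3 exponent + 2 mantissa bits
    lane    = (sign << 15) | (exp_man << 10)         # FP6 left-aligned in a 16-bit lane
    # the cast: keep the sign bit, shift exp+mantissa right by RIGHT_SHIFT = 8 - 3 = 5
    bf16_bits = (lane & 0x8000) | ((lane & 0x7C00) >> 5)
    # a BF16 bit pattern is the top half of an FP32 bit pattern
    small = struct.unpack("<f", struct.pack("<I", bf16_bits << 16))[0]
    # restore the exponent-bias gap: BIAS_OFFSET = 2**(8-1) - 2**(3-1) = 124
    return small * (2.0 ** 124) * scale

# 0b0_110_11 encodes 1.75 * 2**(6-3) = 14.0
assert dequant_fp6_e3m2_to_bf16(0b011011) == 14.0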
- output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(1.0f*BIAS)), Scale); - output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2half(1.0f*BIAS)), Scale); + if (USE_BF16) { + __nv_bfloat16* FP16_1 = reinterpret_cast<__nv_bfloat16*>(&PackedFP16Pair); + __nv_bfloat16* FP16_2 = FP16_1 + 1; + __nv_bfloat16* output_half_ptr = reinterpret_cast<__nv_bfloat16*>(&output); + // TODO: should not do scale conversion here (scale parameter should be bfloat16) + __nv_bfloat16 Scale_bf16 = __float2bfloat16(__half2float(Scale)); + // TODO: it might be faster to do both ops (for [0] and [1]) in one op using __hmul2 + output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2bfloat16(1.0f*BIAS)), Scale_bf16); + output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2bfloat16(1.0f*BIAS)), Scale_bf16); + } else { + half* FP16_1 = reinterpret_cast(&PackedFP16Pair); + half* FP16_2 = FP16_1 + 1; + half* output_half_ptr = reinterpret_cast(&output); + output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(1.0f*BIAS)), Scale); + output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2half(1.0f*BIAS)), Scale); + } return output; } From 125f17c2a057cfc6a84ac1418f3e2249a34574b4 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Mon, 14 Oct 2024 10:56:00 +0200 Subject: [PATCH 04/30] Template progress --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 158 ++++++++++-------- torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh | 62 ++++--- .../csrc/cuda/fp6_llm/kernel_reduction.cuh | 22 ++- torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh | 5 +- torchao/csrc/cuda/fp6_llm/ptx_mma.cuh | 11 +- torchao/csrc/cuda/fp6_llm/utils_core.cuh | 22 +-- torchao/csrc/cuda/fp6_llm/utils_gmem.cuh | 33 ++-- .../cuda/fp6_llm/utils_parallel_dequant.cuh | 59 ++++--- 8 files changed, 210 insertions(+), 162 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 978925a3f7..ed856b21b0 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -41,16 +41,16 @@ inline bool isSM75GPU() { return (major == 7) && (minor == 5); } -template -static void Kernel_Ex(cudaStream_t stream, - const uint4 *Weight, - const half *Scales, - const half *B, - OutputDataType *C, - const size_t M_Global, - const size_t N_Global, - const size_t K_Global, - int Split_K) +template +static void Kernel_Ex(cudaStream_t stream, + const uint4 *Weight, + const InputDataType *Scales, + const InputDataType *B, + OutputDataType *C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + int Split_K) { #ifdef DEBUG_MODE printf("\n"); @@ -59,7 +59,7 @@ static void Kernel_Ex(cudaStream_t stream, printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", TilingConfig::TILE_M, TilingConfig::TILE_K, TilingConfig::TILE_N); #endif static size_t SHMEM_SZ = max(TilingConfig::SMEM_SIZE_B_TILE+SMEM_SIZE_PER_TB_A_TILE, TilingConfig::SMEM_SIZE_C_TILE); - cudaFuncSetAttribute(QUANT_GEMM_Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ); + cudaFuncSetAttribute(QUANT_GEMM_Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ); size_t dimN = (N_Global-1) / TilingConfig::TILE_N + 1; size_t dimM = M_Global * Split_K / TilingConfig::TILE_M; dim3 GridDim(dimN, dimM, 1); @@ -70,22 +70,23 @@ static void Kernel_Ex(cudaStream_t stream, GridDim.x, GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z, SHMEM_SZ); printf("\n"); #endif - QUANT_GEMM_Kernel<<>> + QUANT_GEMM_Kernel<<>> (Weight, Scales, B, C, M_Global, N_Global, K_Global, 
Split_K); } -template -cudaError_t fpx_linear_kernel(cudaStream_t stream, - const uint4 *Weight, - const half *Scales, - const half *B, - half *C, - const size_t M_Global, - const size_t N_Global, - const size_t K_Global, - float *Reduction_Workspace, // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) - int Split_K) +template +cudaError_t fpx_linear_kernel(cudaStream_t stream, + const uint4 *Weight, + const InputDataType *Scales, + const InputDataType *B, + InputDataType *C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + float *Reduction_Workspace, // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) + int Split_K) { + static_assert(std::is_same::value || std::is_same::value, "Type must be float or __nv_bfloat16"); assert(M_Global % 256 == 0); assert(K_Global % 64 == 0); assert(N_Global>0); @@ -102,37 +103,37 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, if (isSM75GPU() && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) { // For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory. if (Split_K == 1) { - Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); } else { - Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); + Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); } } else { if (Split_K == 1) { switch (N_PowerOf2) { - case 8: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 16: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 32: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 64: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 128: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 8: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 16: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 32: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 64: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 128: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; default: if (N_PowerOf2 % 128 != 0) { printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); return cudaErrorUnknown; } - Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; } } else { switch (N_PowerOf2) { - case 8: Kernel_Ex, float, 
EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 16: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 32: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 64: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 128: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 8: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 16: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 32: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 64: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 128: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; default: if (N_PowerOf2 % 128 != 0) { printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); return cudaErrorUnknown; } - Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; } } } @@ -141,7 +142,7 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, // Reduction for SplitK dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1); dim3 BlockDim(WARP_SIZE, 1, 1); - SplitK_Reduction<<>>(C, Reduction_Workspace, M_Global, N_Global, Split_K); + SplitK_Reduction<<>>(C, Reduction_Workspace, M_Global, N_Global, Split_K); } return cudaGetLastError(); @@ -153,6 +154,26 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, #include #include +// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h +// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Half: { \ + using torch_t = at::Half; \ + using nv_t = half; \ + __VA_ARGS__(); \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using torch_t = at::BFloat16; \ + using nv_t = __nv_bfloat16; \ + __VA_ARGS__(); \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + namespace torchao { // MODIFICATION NOTE: dtype of _weights is changed to uint8 /* @@ -163,12 +184,12 @@ Standard definition of linear layer: Out = In * trans(W), where In, Out, and After Equivalent transformation : trans(Out) = W * trans(In). Note that we do not perform "transpose" during runtime, we instead interpret the In/Out as column-major matrices when calling our CUDA kernel. [Inputs] - _in_feats: tensor of shape [B, IC]; // half + _in_feats: tensor of shape [B, IC]; // half or bf16 _weights: int tensor of shape [OC, IC // 8 * x]; // x UINT8 words contains 8 FPx weights. 
- _scales: tensor of shape [OC]; // half + _scales: tensor of shape [OC]; // half or bf16 splitK: spliting the MatMul problem along K dimension for higher GPU utilization, default 1. [Outputs] - _out_feats: tensor of shape [B, OC]; // half + _out_feats: tensor of shape [B, OC]; // half or bf16 */ torch::Tensor fp_eXmY_linear_forward_cuda( int64_t EXPONENT, @@ -184,18 +205,14 @@ torch::Tensor fp_eXmY_linear_forward_cuda( int num_out_channels = _weights.size(0); TORCH_CHECK(num_in_channels % 64 == 0, "Expected in_features to be a multiple of 64, but received ", num_in_channels); TORCH_CHECK((num_in_channels / 8 * NBITS) == _weights.size(1)); // Making sure the K dimension is matched. + TORCH_CHECK(_in_feats.dtype() == _scales.dtype()); // int M = num_out_channels; int K = num_in_channels; int N = num_in_feats; - // Input Tensors - auto weight = reinterpret_cast(_weights.data_ptr()); // weights is [OC, IC] but in FP6. - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto scales = reinterpret_cast(_scales.data_ptr()); - // Output Tensors + auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); at::Tensor _out_feats = torch::empty({num_in_feats, num_out_channels}, options); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); options = torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device()); at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options); @@ -205,26 +222,33 @@ torch::Tensor fp_eXmY_linear_forward_cuda( // this fixes problem with CUDA graphs when used with torch.compile() auto stream = at::cuda::getCurrentCUDAStream(); - // officially supported in Quant-LLM - if (EXPONENT == 3 && MANTISSA == 2) - fpx_linear_kernel<3, 2>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - else if (EXPONENT == 2 && MANTISSA == 2) - fpx_linear_kernel<2, 2>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - - // experimental - else if (EXPONENT == 2 && MANTISSA == 3) - fpx_linear_kernel<2, 3>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - else if (EXPONENT == 3 && MANTISSA == 1) - fpx_linear_kernel<3, 1>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - // else if (EXPONENT == 2 && MANTISSA == 1) - // fpx_linear_kernel<2, 1>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - // else if (EXPONENT == 3 && MANTISSA == 0) - // fpx_linear_kernel<3, 0>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - // else if (EXPONENT == 2 && MANTISSA == 0) - // fpx_linear_kernel<2, 0>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - - else - TORCH_CHECK(false, "FP", NBITS, " E", EXPONENT, "M", MANTISSA, " is not supported."); + DISPATCH_HALF_AND_BF16(_in_feats.scalar_type(), "fpx_linear_kernel", [&] { + auto weight = reinterpret_cast(_weights.data_ptr()); // weights is [OC, IC] but in FP6. 
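For reference, this dispatch is what a BF16 call reaches from Python. A usage sketch mirroring the shapes and call used in test_ops.py further down (illustrative only; it assumes to_scaled_tc_floatx and from_scaled_tc_floatx are importable from torchao.dtypes.floatx, as floatx.py and the tests suggest):

import torch
import torchao
from torchao.dtypes.floatx import to_scaled_tc_floatx, from_scaled_tc_floatx

ebits, mbits, splitK = 3, 2, 1        # FP6 = E3M2
BS, OC, IC = 2, 256, 256              # OC multiple of 256, IC multiple of 64
weight = torch.randn(OC, IC, dtype=torch.bfloat16, device="cuda")
act    = torch.randn(BS, IC, dtype=torch.bfloat16, device="cuda")

fp6_weight, scale = to_scaled_tc_floatx(weight, ebits, mbits)   # packed uint8 + bf16 scale
out = torchao.ops.quant_llm_linear(ebits, mbits, act, fp6_weight, scale, splitK)

ref = act @ from_scaled_tc_floatx(fp6_weight, ebits, mbits, scale).to(torch.bfloat16).T
print((out - ref).abs().mean() / ref.abs().mean())              # small relative error expected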
+ auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto scales = reinterpret_cast(_scales.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + + // officially supported in Quant-LLM + if (EXPONENT == 3 && MANTISSA == 2) + fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 2 && MANTISSA == 2) + fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + + // experimental + else if (EXPONENT == 2 && MANTISSA == 3) + fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 3 && MANTISSA == 1) + fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + // else if (EXPONENT == 2 && MANTISSA == 1) + // fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + // else if (EXPONENT == 3 && MANTISSA == 0) + // fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + // else if (EXPONENT == 2 && MANTISSA == 0) + // fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + + else + TORCH_CHECK(false, "FP", NBITS, " E", EXPONENT, "M", MANTISSA, " is not supported."); + }); return _out_feats; } diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh index 560ca107cc..3c43e0c4eb 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -46,9 +46,9 @@ * B: col major, FP16 * C: col major, FP16 */ - template -__global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, - const half *B, + template +__global__ void QUANT_GEMM_Kernel(const uint4* Weight, const InputDataType* Scales, + const InputDataType *B, OutputDataType* C, const size_t M_Global, const size_t N_Global, const size_t K_Global, int Split_K) @@ -67,9 +67,16 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, const uint4* Weight_2bit = Weight_1bit + (USE_SEG_1BIT ? M_Global*K_Global*BIT_WIDTH_1/128 : 0); const uint4* Weight_4bit = Weight_2bit + (USE_SEG_2BIT ? M_Global*K_Global*BIT_WIDTH_2/128 : 0); // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned - extern __shared__ __align__(128) half smem[]; - half (*smem_array)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = reinterpret_cast ( smem + SMEM_SIZE_PER_TB_A_TILE/2 ); // Dynamic shared memory for FP16 B tiles - __shared__ half QuantScales[64*TilingConfig::BLOCK_WARPS]; // static shared memory for quantization scales, 64 row per warp * 4 warps = 512 Bytes + // extern __shared__ __align__(128) InputDataType smem[]; + // TODO: this is a weird hack (defining smem as 'half' type and then casting + // it to the template type), but for some reason the compiler complains about + // redeclaration of smem if I don't do this. 
It seems to have something to do + // with calling `cudaFuncSetAttribute(QUANT_GEMM_Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ);` + // in `Kernel_Ex` + extern __shared__ __align__(128) half smem1[]; + InputDataType* smem = reinterpret_cast(smem1); + InputDataType (*smem_array)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = reinterpret_cast ( smem + SMEM_SIZE_PER_TB_A_TILE/2 ); // Dynamic shared memory for FP16 B tiles + __shared__ InputDataType QuantScales[64*TilingConfig::BLOCK_WARPS]; // static shared memory for quantization scales, 64 row per warp * 4 warps = 512 Bytes // Thread Block Mapping, considering SplitK const size_t BatchID = blockIdx.y / (M_Global/TilingConfig::TILE_M); const size_t x = blockIdx.x; // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) @@ -117,21 +124,21 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, AFrag_4BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_4BIT/4; // Pre-fetch of A tile for(int i=0; i(AFrag_1BIT_SPTR+i*SMEM_SIZE_PER_WARP_1BIT/4*4, WARP_StartGPTR_A_1BIT); - if(USE_SEG_2BIT) CopyFromGlobalToShared_A(AFrag_2BIT_SPTR+i*SMEM_SIZE_PER_WARP_2BIT/4*4, WARP_StartGPTR_A_2BIT); - if(USE_SEG_4BIT) CopyFromGlobalToShared_A(AFrag_4BIT_SPTR+i*SMEM_SIZE_PER_WARP_4BIT/4*4, WARP_StartGPTR_A_4BIT); + if(USE_SEG_1BIT) CopyFromGlobalToShared_A(AFrag_1BIT_SPTR+i*SMEM_SIZE_PER_WARP_1BIT/4*4, WARP_StartGPTR_A_1BIT); + if(USE_SEG_2BIT) CopyFromGlobalToShared_A(AFrag_2BIT_SPTR+i*SMEM_SIZE_PER_WARP_2BIT/4*4, WARP_StartGPTR_A_2BIT); + if(USE_SEG_4BIT) CopyFromGlobalToShared_A(AFrag_4BIT_SPTR+i*SMEM_SIZE_PER_WARP_4BIT/4*4, WARP_StartGPTR_A_4BIT); WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT/16; WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT/16; WARP_StartGPTR_A_4BIT += SMEM_SIZE_PER_WARP_4BIT/16; } // Global Memory Address for Matrix A (QuantScale) ///////////////////////////////////////////////////////////////////// - const half* TB_StartGPTR_A_Scale = Scales + (y*TilingConfig::BLOCK_ROW_WARPS) * 64; - const half* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64; - CopyFromGlobalToShared_Scales(QuantScales+WARP_i*64, WARP_StartGPTR_A_Scales); + const InputDataType* TB_StartGPTR_A_Scale = Scales + (y*TilingConfig::BLOCK_ROW_WARPS) * 64; + const InputDataType* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64; + CopyFromGlobalToShared_Scales(QuantScales+WARP_i*64, WARP_StartGPTR_A_Scales); // Copying B tile from Global to Shared, considering SplitK ///////////////////////////////////////////////////////////// - const half *BTile_GPTR = B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K; + const InputDataType *BTile_GPTR = B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K; for(int i=0; i (smem_array+i*TilingConfig::TILE_N, BTile_GPTR, K_Global, NumColumnToCopy); + CopyFromGlobalToShared (smem_array+i*TilingConfig::TILE_N, BTile_GPTR, K_Global, NumColumnToCopy); BTile_GPTR += TilingConfig::TILE_K; } // Register Allocation for A,B, and C, Initilazed to Zeros ///////////////////////////////////////////////////////////////////// @@ -151,9 +158,10 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// uint32_t Scales_RPTR[4]; // 4 Registers per thread for Quantization Scales - ExtractFromSharedToReg_Scales(Scales_RPTR, QuantScales + WARP_i*64); + ExtractFromSharedToReg_Scales(Scales_RPTR, QuantScales + WARP_i*64); // 
Initializing the Software Pipeline: writing registers. //////////////////////////////////////////////////////////////////////////////////////////////// - initialize_mma_slice(a, b, AFrag_1BIT_SPTR, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); + constexpr bool USE_BF16 = std::is_same::value; + initialize_mma_slice(a, b, AFrag_1BIT_SPTR, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); // The outer loop. ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// #pragma unroll(1) for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++) @@ -170,29 +178,29 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, uint32_t* __restrict__ write_SPTR_Frag_4bit = AFrag_4BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 // Trible-Buffer for B Tile // MODIFICATION NOTE: to support MSVC, half __restrict__ (*read_SPTR ) is changed to below. similarly for read2_SPTR and write_SPTR. - half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; - half (* __restrict__ read2_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; - half (* __restrict__ write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; + InputDataType (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; + InputDataType (* __restrict__ read2_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; + InputDataType (* __restrict__ write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; // bool GlobalCopy = (tile_id_k+PIPELINE_LEVEL_GMEM-1) < NumIter; // Copying A tile from Global to Register, Bypassing L1, using double-buffer - if(USE_SEG_1BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_1bit, WARP_StartGPTR_A_1BIT, GlobalCopy); - if(USE_SEG_2BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_2bit, WARP_StartGPTR_A_2BIT, GlobalCopy); - if(USE_SEG_4BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_4bit, WARP_StartGPTR_A_4BIT, GlobalCopy); + if(USE_SEG_1BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_1bit, WARP_StartGPTR_A_1BIT, GlobalCopy); + if(USE_SEG_2BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_2bit, WARP_StartGPTR_A_2BIT, GlobalCopy); + if(USE_SEG_4BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_4bit, WARP_StartGPTR_A_4BIT, GlobalCopy); // copying B tile from GlobalMemory to SharedMemory - CopyFromGlobalToShared (write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy); + CopyFromGlobalToShared (write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy); #if __CUDA_ARCH__ >= 800 cp_async_group_commit(); #endif - core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag_2bit, read_SPTR_Frag_4bit are different for each WARP; read_SPTR is shared among WARPs - core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 2); - core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 3); + 
core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag_2bit, read_SPTR_Frag_4bit are different for each WARP; read_SPTR is shared among WARPs + core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 2); + core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 3); // Barriers and Synchronizations #if __CUDA_ARCH__ >= 800 cp_async_wait_group(); #endif __syncthreads(); - core_mma_slice(c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit, read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0); + core_mma_slice(c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit, read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0); // Updating global PTRs WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT/16; // 2KB/16=128 (1)/16: int4*+1 = char*+16 WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT/16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 diff --git a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh index c0e7c1918a..d5742de870 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh @@ -36,17 +36,20 @@ #include #include +#include #include #define REDUCTION_ELEMENT_PER_THREADBLOCK 256 #define HALF_PER_128BIT 8 -__global__ void SplitK_Reduction(half* C, float* Reduction_Workspace, size_t M_Global, size_t N_Global, int Split_K) +template +__global__ void SplitK_Reduction(T* C, float* Reduction_Workspace, size_t M_Global, size_t N_Global, int Split_K) { - half* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; - float* WARP_GPTR_R = Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; - half* THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT; - float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * HALF_PER_128BIT; + static_assert(std::is_same::value || std::is_same::value, "Type must be float or __nv_bfloat16"); + T* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; + float* WARP_GPTR_R = Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; + T* THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT; + float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * HALF_PER_128BIT; // Initializing Thread-Local Results float Results[HALF_PER_128BIT]; #pragma unroll @@ -58,6 +61,11 @@ __global__ void SplitK_Reduction(half* C, float* Reduction_Workspace, size_t M_G THREAD_GPTR_R += M_Global * N_Global; } // Writing to global memory - #pragma unroll - for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]); + if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]); + } else { // __nv_bfloat16> + #pragma unroll + for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2bfloat16_rn(Results[i]); + } } diff --git a/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh b/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh index c1d064f32a..7445db46c8 100644 --- a/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh +++ b/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh @@ -33,10 +33,11 @@ #include #include +#include #include -template -__device__ __forceinline__ void cp_async(half* smem_ptr, const half* global_ptr, bool pred_guard = true) +template +__device__ __forceinline__ void cp_async(T* smem_ptr, const T* global_ptr, bool pred_guard = true) { static_assert(SizeInBytes == 16, 
"Size is not supported"); unsigned smem_int_ptr = __cvta_generic_to_shared(smem_ptr); diff --git a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh index 9708222a88..7c1f066be8 100644 --- a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh +++ b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh @@ -35,6 +35,7 @@ #include #include +#include #include #include @@ -43,9 +44,9 @@ // MODIFICATION NOTE: to support MSVC // - uint32_t __restrict__ Reg[][4] is changed to uint32_t (* __restrict__ Reg)[4] // - half __restrict__ (*read_SPTR) is changed to half (* __restrict__ read_SPTR) -template +template __device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[4], - half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + T (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], int slice_id) { #ifdef DEBUG_MODE static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) ); @@ -82,17 +83,17 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[ asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3]) : "r"(smem_local_ptr)); - smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); + smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(T); } } } // MODIFICATION NOTE: to support MSVC, the function signature is changed from // MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b). +template __device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b) { - constexpr bool USE_BF16 = true; // TODO: don't hardcode here #if __CUDA_ARCH__ == 750 // m16n8k16 op. requires >=sm_80, so instead we use two m16n8k8 ops. asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" @@ -115,7 +116,7 @@ MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); #else - if (USE_BF16) { + if constexpr (USE_BF16) { asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32" "{ %0, %1, %2, %3}," "{ %4, %5, %6, %7 }," diff --git a/torchao/csrc/cuda/fp6_llm/utils_core.cuh b/torchao/csrc/cuda/fp6_llm/utils_core.cuh index 7a6cd36a46..070490a77a 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_core.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_core.cuh @@ -35,13 +35,13 @@ __device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], u } // MODIFICATION NOTE: to support MSVC, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below. 
-template +template __device__ __forceinline__ void initialize_mma_slice(uint32_t (*a)[4], uint32_t (*b)[4], uint32_t* __restrict__ A_1BIT_SPTR_read, uint32_t* __restrict__ A_2BIT_SPTR_read, uint32_t* __restrict__ A_4BIT_SPTR_read, - half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + T (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], uint32_t* RPTR_Scales) { // 1+2+4 weight split @@ -57,19 +57,19 @@ __device__ __forceinline__ void initialize_mma_slice(uint32_t ( if(USE_SEG_1BIT) CopyFromSharedToRegister_AFrag<1> (a_1bit, A_1BIT_SPTR_read, 0); if(USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2> (a_2bit, A_2BIT_SPTR_read, 0); if(USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4> (a_4bit, A_4BIT_SPTR_read, 0); - Dequant_32FP6_4Way(a, a_1bit, a_2bit, a_4bit, RPTR_Scales); // SIMT Dequant: dequantizing FPx to FP16 at register level, dequantizing a slice each time - B_FromSharedToReg(b, B_SPTR_read, 0); // Loading B from shared to registers + Dequant_32FP6_4Way(a, a_1bit, a_2bit, a_4bit, RPTR_Scales); // SIMT Dequant: dequantizing FPx to FP16 at register level, dequantizing a slice each time + B_FromSharedToReg(b, B_SPTR_read, 0); // Loading B from shared to registers } // MODIFICATION NOTE: to support MSVC, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below. -template +template __device__ __forceinline__ void core_mma_slice(float c[][REG_PER_THREAD_C_TENSOR_16_16], uint32_t (*a)[4], uint32_t (*b)[4], uint32_t* __restrict__ A_1bit_SPTR_read, uint32_t* __restrict__ A_2bit_SPTR_read, uint32_t* __restrict__ A_4bit_SPTR_read, - half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + T (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], uint32_t* RPTR_Scales, int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching { @@ -98,13 +98,13 @@ __device__ __forceinline__ void core_mma_slice(float c[][REG #pragma unroll for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { if(TilingConfig::WARP_COL_MMA_TENSORS==1) { - MMA_FP16_M16N8K16( c_uint_ptr[i], a_read[i], b_read[0] ); + MMA_FP16_M16N8K16( c_uint_ptr[i], a_read[i], b_read[0] ); } else { #pragma unroll for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS/2; j++) { - MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a_read[i], b_read[j] ); - MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, a_read[i], b_read[j] + 2 ); // c+4; b+2 + MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a_read[i], b_read[j] ); + MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, a_read[i], b_read[j] + 2 ); // c+4; b+2 } } } @@ -116,8 +116,8 @@ __device__ __forceinline__ void core_mma_slice(float c[][REG if(USE_SEG_1BIT) CopyFromSharedToRegister_AFrag<1> (a_1bit, A_1bit_SPTR_read, slice_id); if(USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2> (a_2bit, A_2bit_SPTR_read, slice_id); if(USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4> (a_4bit, A_4bit_SPTR_read, slice_id); - Dequant_32FP6_4Way(a_write, a_1bit, a_2bit, a_4bit, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time - B_FromSharedToReg (b_write, B_SPTR_read, slice_id); // Loading B from shared to registers + Dequant_32FP6_4Way(a_write, a_1bit, a_2bit, a_4bit, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time + B_FromSharedToReg (b_write, B_SPTR_read, slice_id); // Loading B from shared to registers } template diff --git 
a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh index f2af30733f..ca02c1d7d3 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh @@ -29,7 +29,7 @@ * Copying A1/A2 from global memory to shared memory. * Usually 1024 or 2048 Bytes */ -template +template __device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, const uint4* GPTR, bool pred_guard = true) { @@ -37,23 +37,23 @@ __device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, static_assert(SMEM_SIZE_IN_BYTES_PER_WARP/WARP_SIZE % 16 == 0); #endif int lane_id = threadIdx.x % WARP_SIZE; - half* SPTR_HALF = reinterpret_cast(SPTR); - const half* GPTR_HALF = reinterpret_cast(GPTR); - SPTR_HALF += lane_id*8; - GPTR_HALF += lane_id*8; + T* SPTR_T = reinterpret_cast(SPTR); + const T* GPTR_T = reinterpret_cast(GPTR); + SPTR_T += lane_id*8; + GPTR_T += lane_id*8; #pragma unroll for(int i=0; i(SPTR_HALF); - const float4* GPTR_VEC = reinterpret_cast(GPTR_HALF); + float4* SPTR_VEC = reinterpret_cast(SPTR_T); + const float4* GPTR_VEC = reinterpret_cast(GPTR_T); SPTR_VEC[0] = GPTR_VEC[0]; } #else - cp_async<16>( SPTR_HALF, GPTR_HALF, pred_guard); + cp_async( SPTR_T, GPTR_T, pred_guard); #endif - SPTR_HALF += 256; // Forward 512 Bytes - GPTR_HALF += 256; // Forward 512 Bytes + SPTR_T += 256; // Forward 512 Bytes + GPTR_T += 256; // Forward 512 Bytes } } @@ -61,8 +61,9 @@ __device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, /* * Copying 64 Quant Scales (FP16) from global memory to shared memory. */ -__device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantScales, - const half* GPTR_A_Scales) { +template +__device__ __forceinline__ void CopyFromGlobalToShared_Scales(T* SPTR_QuantScales, + const T* GPTR_A_Scales) { int lane_id = threadIdx.x % WARP_SIZE; int Offset_Shared = lane_id*2; int Offset_Global = lane_id/4 + (lane_id%4)*16; @@ -75,9 +76,9 @@ __device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantSc * (2) Copying 64 rows * X columns of FP16 values, originally in column major * 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8 Threads */ -template -__device__ __forceinline__ void CopyFromGlobalToShared(half (* __restrict__ SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], - const half* GlobalPTR, +template +__device__ __forceinline__ void CopyFromGlobalToShared(T (* __restrict__ SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + const T* GlobalPTR, const int GlobalStride, const int NumOfLinesLeft, // To support arbitrary N dimensions. bool Pred = true) { @@ -101,7 +102,7 @@ __device__ __forceinline__ void CopyFromGlobalToShared(half (* __restrict__ Shar SharedPtrVec[0] = GlobalPtrVec[0]; } #else - cp_async<16>( &(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred); + cp_async( &(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred); #endif GlobalPTR += NumOfGroups * GlobalStride; SharedPTR += NumOfGroups; diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index c97c549bc1..baf7638878 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -28,10 +28,9 @@ * Outputs: R1, R2 * Note: Simplified Exponent calculation is applied. 
*/ -template +template __device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t *In, uint32_t *Out1, uint32_t *Out2) { // - constexpr bool USE_BF16 = true; // TODO: don't hardcode here constexpr int RIGHT_SHIFT = USE_BF16 ? 8 - EXPONENT : 5 - EXPONENT; constexpr int MASK1 = 0x80000000; constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA; // NB: arithmetic shift, not logical @@ -48,35 +47,39 @@ __device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t *In, uint32_t *Out1, template __device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scale) { - constexpr bool USE_BF16 = true; // TODO: don't hardcode here - constexpr int EXP_16 = USE_BF16 ? 8 : 5; - constexpr int BIAS_OFFSET = (int(1) << (EXP_16-1)) - (int(1) << (EXPONENT-1)); + constexpr int BIAS_OFFSET = (int(1) << (5-1)) - (int(1) << (EXPONENT-1)); constexpr int BIAS = int(1) << BIAS_OFFSET; // + half* FP16_1 = reinterpret_cast(&PackedFP16Pair); + half* FP16_2 = FP16_1 + 1; uint32_t output; - if (USE_BF16) { - __nv_bfloat16* FP16_1 = reinterpret_cast<__nv_bfloat16*>(&PackedFP16Pair); - __nv_bfloat16* FP16_2 = FP16_1 + 1; - __nv_bfloat16* output_half_ptr = reinterpret_cast<__nv_bfloat16*>(&output); - // TODO: should not do scale conversion here (scale parameter should be bfloat16) - __nv_bfloat16 Scale_bf16 = __float2bfloat16(__half2float(Scale)); - // TODO: it might be faster to do both ops (for [0] and [1]) in one op using __hmul2 - output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2bfloat16(1.0f*BIAS)), Scale_bf16); - output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2bfloat16(1.0f*BIAS)), Scale_bf16); - } else { - half* FP16_1 = reinterpret_cast(&PackedFP16Pair); - half* FP16_2 = FP16_1 + 1; - half* output_half_ptr = reinterpret_cast(&output); - output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(1.0f*BIAS)), Scale); - output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2half(1.0f*BIAS)), Scale); - } + half* output_half_ptr = reinterpret_cast(&output); + output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(1.0f*BIAS)), Scale); + output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2half(1.0f*BIAS)), Scale); + return output; +} + +template +__device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bfloat16 Scale) { + constexpr int BIAS_OFFSET = (int(1) << (8-1)) - (int(1) << (EXPONENT-1)); + constexpr int BIAS = int(1) << BIAS_OFFSET; + // + __nv_bfloat16* BF16_1 = reinterpret_cast<__nv_bfloat16*>(&PackedBF16Pair); + __nv_bfloat16* BF16_2 = BF16_1 + 1; + uint32_t output; + __nv_bfloat16* output_bf16_ptr = reinterpret_cast<__nv_bfloat16*>(&output); + // TODO: it might be faster to do both ops (for [0] and [1]) in one op using __hmul2 + // TODO: this multiplication with bias is potentially problematic and might also slow things down. + // Note that BIAS has a value of 2^120 -> what happens when the int overflows? + output_bf16_ptr[0] = __hmul( __hmul(*BF16_1,__float2bfloat16(1.0f*BIAS)), Scale); + output_bf16_ptr[1] = __hmul( __hmul(*BF16_2,__float2bfloat16(1.0f*BIAS)), Scale); return output; } // MODIFICATION NOTE: to support MSVC // - u_int32_t __restrict__ Reg[][4] is changed to below. // - u_int32_t __restrict__ *read_RPTR_1bit is changed to below. 
similarly for read_RPTR_2bit and read_RPTR_4bit -template +template __device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg)[4], uint32_t * __restrict__ read_RPTR_1bit, uint32_t * __restrict__ read_RPTR_2bit, @@ -92,7 +95,8 @@ __device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg) uint32_t *Frag_PTR_1bit = read_RPTR_1bit; uint32_t *Frag_PTR_2bit = read_RPTR_2bit; uint32_t *Frag_PTR_4bit = read_RPTR_4bit; - half *Scale_RPTR = reinterpret_cast(Scales); + using scalar_t = typename std::conditional::type; + scalar_t *Scale_RPTR = reinterpret_cast(Scales); // Dequantizing 32 FP6, each Loop dequantizing 4 FP6 #pragma unroll(8) for(int i=0; i<8; i++) { @@ -122,13 +126,13 @@ __device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg) // Packed_FP6 now contains 4x 1234 5600 // uint32_t out1, out2; - FPx_FP16_Cast_4Way(&Packed_FP6, &out1, &out2); + FPx_FP16_Cast_4Way(&Packed_FP6, &out1, &out2); // out1 now contains 2 FP16 values, as shown by R1 in figure 6 // out2 now contains 2 FP16 values, as shown by R2 in figure 6 // - *OutputRegs = MultScale(out1, Scale_RPTR[0] ); // Muliply FP16 scales + *OutputRegs = MultScale(out1, Scale_RPTR[0]); // Muliply FP16 scales OutputRegs += 1; - *OutputRegs = MultScale(out2, Scale_RPTR[1]); // Muliply FP16 scales + *OutputRegs = MultScale(out2, Scale_RPTR[1]); // Muliply FP16 scales OutputRegs += 1; // Updating offset for FP16 scales for every two iterations if(i%2==1) Scale_RPTR += 2; @@ -139,7 +143,8 @@ __device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg) /* * */ -__device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, half* WARP_SPTR_Scales) { +template +__device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, T* WARP_SPTR_Scales) { int lane_id = threadIdx.x % WARP_SIZE; uint32_t* SPTR_uint = reinterpret_cast(WARP_SPTR_Scales); uint32_t tmpReg = SPTR_uint[lane_id]; From b3c3be0a9c756f91ef60bead08bcfdaafd3dfe95 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Fri, 18 Oct 2024 12:40:35 +0200 Subject: [PATCH 05/30] Modify fpx quant logic to include bf16 --- torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh | 2 ++ torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh | 2 ++ torchao/csrc/cuda/fp6_llm/ptx_mma.cuh | 2 ++ torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh | 3 +++ torchao/dtypes/affine_quantized_tensor.py | 10 +++++----- torchao/dtypes/floatx/floatx.py | 3 ++- torchao/ops.py | 4 ++-- torchao/quantization/quant_primitives.py | 3 ++- 8 files changed, 20 insertions(+), 9 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh index d5742de870..b8779b987c 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh @@ -36,7 +36,9 @@ #include #include +// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #include +// #endif #include #define REDUCTION_ELEMENT_PER_THREADBLOCK 256 diff --git a/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh b/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh index 7445db46c8..7ce05cefc8 100644 --- a/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh +++ b/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh @@ -33,7 +33,9 @@ #include #include +// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #include +// #endif #include template diff --git a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh index 
7c1f066be8..cbaf8b14f6 100644 --- a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh +++ b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh @@ -35,7 +35,9 @@ #include #include +// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #include +// #endif #include #include diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index baf7638878..f4461e3c36 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -20,7 +20,10 @@ #include #include +// TODO: can cuda_bf16 be imported for SM75? How to guard against this? The guard below does not work outside of device code +// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #include +// #endif #include /* diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 75d178fb50..34156697f2 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1612,13 +1612,13 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): y += bias.to(m.dtype) return y -def _linear_f16_act_floatx_weight_check(input_tensor, weight_tensor, bias): +def _linear_f16_bf16_act_floatx_weight_check(input_tensor, weight_tensor, bias): from torchao.dtypes.floatx import FloatxTensorCoreLayout return ( # input is native float32 tensor not is_traceable_wrapper_subclass(input_tensor) and input_tensor.is_floating_point() and - input_tensor.dtype == torch.float16 and + input_tensor.dtype in (torch.float16, torch.bfloat16) and # weight is floatx Tensor isinstance(weight_tensor, AffineQuantizedTensor) and isinstance(weight_tensor._layout, FloatxTensorCoreLayout) and @@ -1636,7 +1636,7 @@ def _linear_f16_act_floatx_weight_check(input_tensor, weight_tensor, bias): ) ) -def _linear_f16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): +def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): from torchao.dtypes.floatx import _SPLIT_K_MAP from torchao.ops import quant_llm_linear @@ -1644,7 +1644,7 @@ def _linear_f16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): weight = weight_tensor out_dim, in_dim = weight.shape - act_reshaped = act.view(-1, in_dim).half() + act_reshaped = act.view(-1, in_dim) # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py bsize = act_reshaped.shape[0] @@ -1804,7 +1804,7 @@ def _register_aqt_quantized_linear_dispatches(): (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl), (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), - (_linear_f16_act_floatx_weight_check, _linear_f16_act_floatx_weight_impl), + (_linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl), (_linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) diff --git a/torchao/dtypes/floatx/floatx.py b/torchao/dtypes/floatx/floatx.py index f862106373..a4745e9315 100644 --- a/torchao/dtypes/floatx/floatx.py +++ b/torchao/dtypes/floatx/floatx.py @@ -128,11 +128,12 @@ def to_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, exp_bias = _ONES_TABLE[ebits - 1] max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits)) + dtype = tensor.dtype tensor = 
tensor.float() scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal tensor_floatx = _f32_to_floatx_unpacked(tensor / scale.view(-1, 1), ebits, mbits) tensor_tc_floatx = pack_tc_floatx(tensor_floatx, 1 + ebits + mbits) - return tensor_tc_floatx, scale.half() + return tensor_tc_floatx, scale.to(dtype) # inverse of _pack_tc_floatx() diff --git a/torchao/ops.py b/torchao/ops.py index 79c02dfd85..fa8ad7fe89 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -55,11 +55,11 @@ def _( splitK: int = 1, ) -> Tensor: torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D") - torch._check(_in_feats.dtype is torch.float16, lambda: f"weight must be FP16, got {_in_feats.dtype}") + torch._check(_in_feats.dtype in (torch.float16, torch.bfloat16), lambda: f"weight must be FP16 or BF16, got {_in_feats.dtype}") torch._check(_weights.dim() == 2, lambda: f"weight should be a 2d tensor, got {_weights.dim()}D") torch._check(_weights.dtype is torch.uint8, lambda: f"weight must be UINT8, got {_weights.dtype}") torch._check(_scales.dim() == 1, lambda: f"scale should be a 2d tensor, got {_scales.dim()}D") - torch._check(_scales.dtype is torch.float16, lambda: f"scale must be FP16, got {_scales.dtype}") + torch._check(_scales.dtype in (torch.float16, torch.bfloat16), lambda: f"scale must be FP16 or BF16, got {_scales.dtype}") BS, IC = _in_feats.shape OC, _ = _weights.shape diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index dfd3bcaad8..a8ac533740 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -1003,9 +1003,10 @@ def choose_qparams_affine_floatx(tensor: torch.Tensor, ebits: int, mbits: int) - exp_bias = _ONES_TABLE[ebits - 1] max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits)) + dtype = tensor.dtype tensor = tensor.float() scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal - return scale.half() + return scale.to(dtype) def quantize_affine_floatx(tensor: torch.Tensor, scale: torch.Tensor, ebits: int, mbits: int) -> torch.Tensor: """Quantizes the float32 high precision floating point tensor to low precision floating point number and From f82876326569fd89a3c25ee3a8477c238ce44c09 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Fri, 18 Oct 2024 13:01:15 +0200 Subject: [PATCH 06/30] Add tests for FP6 BF16 --- test/dtypes/test_floatx.py | 7 ++++--- test/test_ops.py | 18 ++++++++++-------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index 93dc7515d9..875a8c8d5e 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -91,16 +91,17 @@ def test_to_copy_device(self, ebits, mbits): @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+") @parametrize("ebits,mbits", _Floatx_DTYPES) @parametrize("bias", [False, True]) + @parametrize("dtype", [torch.half, torch.bfloat16]) @pytest.mark.skipif(is_fbcode(), reason="broken in fbcode") - def test_fpx_weight_only(self, ebits, mbits, bias): + def test_fpx_weight_only(self, ebits, mbits, bias, dtype): N, OC, IC = 4, 256, 64 device = "cuda" - linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=torch.half) + linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=dtype) fpx_linear = copy.deepcopy(linear) quantize_(fpx_linear, fpx_weight_only(ebits, mbits)) - x 
= torch.randn(N, IC, device=device, dtype=torch.half) + x = torch.randn(N, IC, device=device, dtype=dtype) expected = fpx_linear(x) actual = torch.compile(fpx_linear, fullgraph=True)(x) # somehow compile now changes the result a bit diff --git a/test/test_ops.py b/test/test_ops.py index 31000eafc2..c174cec794 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -33,22 +33,23 @@ class TestOps(TestCase): - def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device): + def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device, dtype): # Randomly initialize each byte nbits = 1 + ebits + mbits floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8) - scale = torch.rand(OC).half() + 0.5 - fp16_act = torch.rand(BS, IC).half() + 0.5 + scale = torch.rand(OC).to(dtype) + 0.5 + fp16_act = torch.rand(BS, IC).to(dtype) + 0.5 return floatx_weight.to(device), scale.to(device), fp16_act.to(device) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @parametrize("ebits,mbits", [(3, 2), (2, 2)]) - def test_quant_llm_linear(self, ebits, mbits): + @parametrize("dtype", [torch.half, torch.bfloat16]) + def test_quant_llm_linear(self, ebits, mbits, dtype): BS = 2 OC = 256 IC = 256 splitK = 1 - floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda") + floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda", dtype) # smoke test torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK) @@ -60,13 +61,14 @@ def test_quant_llm_linear(self, ebits, mbits): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) @parametrize("ebits,mbits", [(3, 2), (2, 2)]) - def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK): + @parametrize("dtype", [torch.half, torch.bfloat16]) + def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK, dtype): # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py - floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda") + floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda", dtype) results_floatx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK) - fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).half() + fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).to(dtype) results_fp16 = fp16_act @ fp16_weight.T error = (results_floatx - results_fp16).abs().mean() From ff2c6e8c9808eb48cf488334649d4a554c30b43f Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Tue, 22 Oct 2024 16:16:01 +0200 Subject: [PATCH 07/30] Use type punning for large exponent multiplication --- benchmarks/benchmark_fp6.py | 46 +++++++----- .../cuda/fp6_llm/utils_parallel_dequant.cuh | 73 +++++++++++++++++-- 2 files changed, 95 insertions(+), 24 deletions(-) diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py index c5ae9bfef6..eb0bd1f5f3 100644 --- a/benchmarks/benchmark_fp6.py +++ b/benchmarks/benchmark_fp6.py @@ -7,31 +7,43 @@ from tqdm import tqdm -def benchmark(m: int, k: int, n: int, use_bf16=False): - dtype = torch.bfloat16 if use_bf16 else torch.half - float_data = 
torch.randn(n, k, dtype=dtype, device="cuda") - fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayout(3, 2)) - fp16_weight = fp6_weight.dequantize(dtype) - - fp16_act = torch.randn(m, k, dtype=dtype, device="cuda") - fp6_output = F.linear(fp16_act, fp6_weight) +def benchmark(m: int, k: int, n: int): + float_data_fp16 = torch.randn(n, k, dtype=torch.float16, device="cuda") + float_data_bf16 = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") + fp6_weight_fp16 = to_affine_quantized_fpx(float_data_fp16, FloatxTensorCoreLayout(3, 2)) + fp6_weight_bf16 = to_affine_quantized_fpx(float_data_bf16, FloatxTensorCoreLayout(3, 2)) + fp16_weight = fp6_weight_fp16.dequantize(torch.float16) + bf16_weight = fp6_weight_bf16.dequantize(torch.bfloat16) + + fp16_act = torch.randn(m, k, dtype=torch.float16, device="cuda") + bf16_act = fp16_act.to(torch.bfloat16) + fp6_output_fp16 = F.linear(fp16_act, fp6_weight_fp16) + fp6_output_bf16 = F.linear(bf16_act, fp6_weight_bf16) fp16_output = F.linear(fp16_act, fp16_weight) + bf16_output = F.linear(bf16_act, bf16_weight) - fp6_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight) fp16_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight) + bf16_time = benchmark_torch_function_in_microseconds(F.linear, bf16_act, bf16_weight) + fp6_time_fp16 = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight_fp16) + fp6_time_bf16 = benchmark_torch_function_in_microseconds(F.linear, bf16_act, fp6_weight_bf16) # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py # doesn't seem to be the right way to check for correctness - correct = (fp6_output - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3 + correct_fp16 = (fp6_output_fp16 - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3 + correct_bf16 = (fp6_output_bf16 - bf16_output).abs().mean() / bf16_output.abs().mean() < 1e-2 return { "m": m, "k": k, "n": n, - "fp6_latency (ms)": fp6_time, - f"{'bf16' if use_bf16 else 'fp16'}_latency (ms)": fp16_time, - "speedup (d/s)": fp16_time / fp6_time, - "correct": correct, + "fp6-fp16 latency (ms)": fp6_time_fp16, + "fp16 latency (ms)": fp16_time, + "speedup fp16": fp16_time / fp6_time_fp16, + "correct fp16": correct_fp16, + "fp6-bf16 latency (ms)": fp6_time_bf16, + "bf16 latency (ms)": bf16_time, + "speedup bf16": bf16_time / fp6_time_bf16, + "correct bf16": correct_bf16, } @@ -40,13 +52,11 @@ def benchmark(m: int, k: int, n: int, use_bf16=False): k_vals = (8192, 8192, 8192, 28672) n_vals = (8192, 10240, 57344, 8192) - use_bf16 = True - results = [] - for m in tqdm([1 << i for i in range(10)]): + for m in tqdm([1 << i for i in range(5)]): # TODO: reset to 10 for n, k in zip(n_vals, k_vals): - results.append(benchmark(m, k, n, use_bf16=use_bf16)) + results.append(benchmark(m, k, n)) df = pd.DataFrame(results) df.to_csv("fp6_llm_benchmark_results.csv", index=False) diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index f4461e3c36..128c8559ad 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -46,6 +46,43 @@ __device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t *In, uint32_t *Out1, *In = (*In) << 8; *Out2 = *In & 0x80008000; *Out2 |= ( (*In) & MASK ) >> RIGHT_SHIFT; + + if constexpr (false && USE_BF16) { + // Add exponent bias + constexpr int MASK1_EXP = 
0x80000000; + constexpr int MASK2_EXP = MASK1_EXP >> (EXPONENT + RIGHT_SHIFT); // NB: arithmetic shift, not logical + constexpr int MASK3_EXP = MASK2_EXP & 0x7fffffff; + constexpr int MASK_EXP = MASK3_EXP | MASK3_EXP >> 16; + // Extract exponents bits + union { + uint32_t u32; + uint8_t u8[4]; + uint16_t u16[2]; + } tmp1, tmp2; + /* + constexpr uint16_t BIAS_OFFSET = (1 << (8-1)) - (1 << (EXPONENT-1)); // 124 = 0x7c + tmp1.u32 = (*Out1 & MASK_EXP) >> 7; + tmp2.u32 = (*Out2 & MASK_EXP) >> 7; + tmp1.u16[0] += BIAS_OFFSET; + tmp1.u16[1] += BIAS_OFFSET; + tmp2.u16[0] += BIAS_OFFSET; + tmp2.u16[1] += BIAS_OFFSET; + *Out1 = (*Out1 & ~MASK_EXP) | (tmp1.u32 << 7); + *Out2 = (*Out2 & ~MASK_EXP) | (tmp2.u32 << 7); + */ + // /* + constexpr uint8_t BIAS_OFFSET = (1 << (8-1)) - (1 << (EXPONENT-1)); // 124 = 0x7c + tmp1.u32 = (*Out1 & MASK_EXP) << 1; + tmp2.u32 = (*Out2 & MASK_EXP) << 1; + // NB: little endian + tmp1.u8[3] += BIAS_OFFSET; + tmp1.u8[1] += BIAS_OFFSET; + tmp2.u8[3] += BIAS_OFFSET; + tmp2.u8[1] += BIAS_OFFSET; + *Out1 = (*Out1 & ~MASK_EXP) | (tmp1.u32 >> 1); + *Out2 = (*Out2 & ~MASK_EXP) | (tmp2.u32 >> 1); + // */ + } } template @@ -65,17 +102,41 @@ __device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scal template __device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bfloat16 Scale) { constexpr int BIAS_OFFSET = (int(1) << (8-1)) - (int(1) << (EXPONENT-1)); - constexpr int BIAS = int(1) << BIAS_OFFSET; - // __nv_bfloat16* BF16_1 = reinterpret_cast<__nv_bfloat16*>(&PackedBF16Pair); __nv_bfloat16* BF16_2 = BF16_1 + 1; uint32_t output; __nv_bfloat16* output_bf16_ptr = reinterpret_cast<__nv_bfloat16*>(&output); + if constexpr (false) { + // Bfloat16 exponent bias is 127, which would lead to multiplication with + // 2^127, which would lead to overflow. Instead, we decompose the exponent + // into smaller values and multiply several times. + __nv_bfloat16 tmp1 = *BF16_1; + __nv_bfloat16 tmp2 = *BF16_2; + // FIXME: only works for exponent=3 right now. + // Note that for exponent=3, BIAS_OFFSET = 2^7 - 2^2 = 124 = 4*31 + const __nv_bfloat16 BIAS = __float2bfloat16(1.0f * (uint32_t(1) << BIAS_OFFSET / 4)); + #pragma unroll + for (int i = 0; i < 4; i++) { + tmp1 = __hmul(tmp1, BIAS); + tmp2 = __hmul(tmp2, BIAS); + } + output_bf16_ptr[0] = __hmul( tmp1, Scale); + output_bf16_ptr[1] = __hmul( tmp2, Scale); + } else { + // Bfloat16 exponent bias is 127, which would lead to multiplication with + // 2^127, which would lead to overflow. Instead, we use type punning to + // directly construct a float with a large exponent. + union { + uint32_t u32; + float f; + } tmp; + tmp.u32 = (BIAS_OFFSET + 127) << 23; // 127=exponent bias, 23=mantissa + output_bf16_ptr[0] = __hmul( __hmul(*BF16_1,__float2bfloat16(tmp.f)), Scale); + output_bf16_ptr[1] = __hmul( __hmul(*BF16_2,__float2bfloat16(tmp.f)), Scale); + } // TODO: it might be faster to do both ops (for [0] and [1]) in one op using __hmul2 - // TODO: this multiplication with bias is potentially problematic and might also slow things down. - // Note that BIAS has a value of 2^120 -> what happens when the int overflows? 
- output_bf16_ptr[0] = __hmul( __hmul(*BF16_1,__float2bfloat16(1.0f*BIAS)), Scale); - output_bf16_ptr[1] = __hmul( __hmul(*BF16_2,__float2bfloat16(1.0f*BIAS)), Scale); + // output_bf16_ptr[0] = __hmul( *BF16_1, Scale); + // output_bf16_ptr[1] = __hmul( *BF16_2, Scale); return output; } From 4304dcc904e96fdfeb43ab5622d39882375fc84a Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Wed, 23 Oct 2024 14:10:16 +0200 Subject: [PATCH 08/30] Fix some TODOs --- benchmarks/benchmark_fp6.py | 2 +- torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh | 13 +++----- .../cuda/fp6_llm/utils_parallel_dequant.cuh | 31 ++++++++++--------- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py index eb0bd1f5f3..25967baa25 100644 --- a/benchmarks/benchmark_fp6.py +++ b/benchmarks/benchmark_fp6.py @@ -54,7 +54,7 @@ def benchmark(m: int, k: int, n: int): results = [] - for m in tqdm([1 << i for i in range(5)]): # TODO: reset to 10 + for m in tqdm([1 << i for i in range(10)]): for n, k in zip(n_vals, k_vals): results.append(benchmark(m, k, n)) diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh index 3c43e0c4eb..2c4fb000a9 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -67,14 +67,11 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const InputDataType* Scal const uint4* Weight_2bit = Weight_1bit + (USE_SEG_1BIT ? M_Global*K_Global*BIT_WIDTH_1/128 : 0); const uint4* Weight_4bit = Weight_2bit + (USE_SEG_2BIT ? M_Global*K_Global*BIT_WIDTH_2/128 : 0); // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned - // extern __shared__ __align__(128) InputDataType smem[]; - // TODO: this is a weird hack (defining smem as 'half' type and then casting - // it to the template type), but for some reason the compiler complains about - // redeclaration of smem if I don't do this. It seems to have something to do - // with calling `cudaFuncSetAttribute(QUANT_GEMM_Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ);` - // in `Kernel_Ex` - extern __shared__ __align__(128) half smem1[]; - InputDataType* smem = reinterpret_cast(smem1); + extern __shared__ __align__(128) half smem_[]; + // Defining smem like this is necessary for templated kernels (defining it as + // a fixed type and then casting it to the template type). See + // https://leimao.github.io/blog/CUDA-Shared-Memory-Templated-Kernel/ for details. 
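    // A minimal standalone sketch of the pattern described in the comment above (illustrative
    // only; the kernel name `copy_via_smem` is made up): an `extern __shared__` buffer cannot be
    // redeclared with a different element type for each template instantiation, so it is declared
    // once with a fixed type and reinterpreted as the template type.
    //
    //   template <typename T>
    //   __global__ void copy_via_smem(T* out, const T* in, int n) {
    //       extern __shared__ __align__(128) unsigned char smem_raw[];  // fixed element type
    //       T* smem = reinterpret_cast<T*>(smem_raw);                   // viewed as the template type
    //       int i = blockIdx.x * blockDim.x + threadIdx.x;
    //       if (i < n) {
    //           smem[threadIdx.x] = in[i];
    //           out[i] = smem[threadIdx.x];
    //       }
    //   }
    //
    //   // Example launch, sizing the dynamic shared memory explicitly:
    //   //   copy_via_smem<float><<<blocks, threads, threads * sizeof(float)>>>(out, in, n);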
+ InputDataType* smem = reinterpret_cast(smem_); InputDataType (*smem_array)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = reinterpret_cast ( smem + SMEM_SIZE_PER_TB_A_TILE/2 ); // Dynamic shared memory for FP16 B tiles __shared__ InputDataType QuantScales[64*TilingConfig::BLOCK_WARPS]; // static shared memory for quantization scales, 64 row per warp * 4 warps = 512 Bytes // Thread Block Mapping, considering SplitK diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index 128c8559ad..05cd67504a 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -48,19 +48,21 @@ __device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t *In, uint32_t *Out1, *Out2 |= ( (*In) & MASK ) >> RIGHT_SHIFT; if constexpr (false && USE_BF16) { - // Add exponent bias + // This snippet adds the exponent bias directly to the exponent bits instead of constructing a bf16 type and + // multiplying with 2^bias in the `MultScale()` function. However, this option is slower, so I don't use it. + // Figure 6 in the FP6 paper provides a helpful visualization of how the FP6 bits are layed out. constexpr int MASK1_EXP = 0x80000000; constexpr int MASK2_EXP = MASK1_EXP >> (EXPONENT + RIGHT_SHIFT); // NB: arithmetic shift, not logical constexpr int MASK3_EXP = MASK2_EXP & 0x7fffffff; constexpr int MASK_EXP = MASK3_EXP | MASK3_EXP >> 16; - // Extract exponents bits + // Extract exponent bits union { uint32_t u32; uint8_t u8[4]; uint16_t u16[2]; } tmp1, tmp2; /* - constexpr uint16_t BIAS_OFFSET = (1 << (8-1)) - (1 << (EXPONENT-1)); // 124 = 0x7c + constexpr uint16_t BIAS_OFFSET = (1 << (8-1)) - (1 << (EXPONENT-1)); tmp1.u32 = (*Out1 & MASK_EXP) >> 7; tmp2.u32 = (*Out2 & MASK_EXP) >> 7; tmp1.u16[0] += BIAS_OFFSET; @@ -70,18 +72,17 @@ __device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t *In, uint32_t *Out1, *Out1 = (*Out1 & ~MASK_EXP) | (tmp1.u32 << 7); *Out2 = (*Out2 & ~MASK_EXP) | (tmp2.u32 << 7); */ - // /* - constexpr uint8_t BIAS_OFFSET = (1 << (8-1)) - (1 << (EXPONENT-1)); // 124 = 0x7c + constexpr uint8_t BIAS_OFFSET = (1 << (8-1)) - (1 << (EXPONENT-1)); tmp1.u32 = (*Out1 & MASK_EXP) << 1; tmp2.u32 = (*Out2 & MASK_EXP) << 1; - // NB: little endian + // Add exponent bias to exponent bits (NB: little endian) tmp1.u8[3] += BIAS_OFFSET; tmp1.u8[1] += BIAS_OFFSET; tmp2.u8[3] += BIAS_OFFSET; tmp2.u8[1] += BIAS_OFFSET; + // Insert new exponent bits *Out1 = (*Out1 & ~MASK_EXP) | (tmp1.u32 >> 1); *Out2 = (*Out2 & ~MASK_EXP) | (tmp2.u32 >> 1); - // */ } } @@ -107,9 +108,10 @@ __device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bflo uint32_t output; __nv_bfloat16* output_bf16_ptr = reinterpret_cast<__nv_bfloat16*>(&output); if constexpr (false) { - // Bfloat16 exponent bias is 127, which would lead to multiplication with - // 2^127, which would lead to overflow. Instead, we decompose the exponent - // into smaller values and multiply several times. + // Exponent bias is 124, which would lead to multiplication with 2^124, + // which would lead to overflow when stored in a 32 or 64-bit type. + // Instead, we decompose the exponent into smaller values and multiply + // several times. __nv_bfloat16 tmp1 = *BF16_1; __nv_bfloat16 tmp2 = *BF16_2; // FIXME: only works for exponent=3 right now. 
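        // A minimal host-side check of the bias constant discussed above (illustrative only, not
        // kernel code; assumes the FP6 E3M2 layout): BIAS_OFFSET is 124, and the bit pattern
        // `(BIAS_OFFSET + 127) << 23` built in `MultScale` is exactly the IEEE-754 float 2^124,
        // which overflows every integer type but is finite in fp32.
        //
        //   #include <cassert>
        //   #include <cmath>
        //   #include <cstdint>
        //   #include <cstring>
        //
        //   int main() {
        //       constexpr int EXPONENT = 3;                                          // FP6 = E3M2
        //       constexpr int BIAS_OFFSET = (1 << (8 - 1)) - (1 << (EXPONENT - 1));  // 128 - 4 = 124
        //       uint32_t bits = uint32_t(BIAS_OFFSET + 127) << 23;                   // exponent field only, zero mantissa
        //       float bias;
        //       std::memcpy(&bias, &bits, sizeof(bias));                             // same punning as the union in `MultScale`
        //       assert(bias == std::ldexp(1.0f, BIAS_OFFSET));                       // == 2^124
        //       return 0;
        //   }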
@@ -123,9 +125,10 @@ __device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bflo output_bf16_ptr[0] = __hmul( tmp1, Scale); output_bf16_ptr[1] = __hmul( tmp2, Scale); } else { - // Bfloat16 exponent bias is 127, which would lead to multiplication with - // 2^127, which would lead to overflow. Instead, we use type punning to - // directly construct a float with a large exponent. + // Exponent bias is 124, which would lead to multiplication with 2^124, + // which would lead to overflow when stored in a 32 or 64-bit type. + // Instead, we use type punning to directly construct a float with a + // large exponent. union { uint32_t u32; float f; @@ -134,7 +137,7 @@ __device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bflo output_bf16_ptr[0] = __hmul( __hmul(*BF16_1,__float2bfloat16(tmp.f)), Scale); output_bf16_ptr[1] = __hmul( __hmul(*BF16_2,__float2bfloat16(tmp.f)), Scale); } - // TODO: it might be faster to do both ops (for [0] and [1]) in one op using __hmul2 + // Use this if exponent bias is already added in `FPx_FP16_Cast_4Way` // output_bf16_ptr[0] = __hmul( *BF16_1, Scale); // output_bf16_ptr[1] = __hmul( *BF16_2, Scale); return output; From 2d00a3aa3b7c783d4b884b85e05209dc8413765a Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Wed, 23 Oct 2024 14:12:43 +0200 Subject: [PATCH 09/30] Remove option to add exponent bias directly to the exponent bits This approach is (much) slower than multiplying by 2^bias after the fact, so that's why it's not usable --- .../cuda/fp6_llm/utils_parallel_dequant.cuh | 41 ------------------- 1 file changed, 41 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index 05cd67504a..87a15fa726 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -46,44 +46,6 @@ __device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t *In, uint32_t *Out1, *In = (*In) << 8; *Out2 = *In & 0x80008000; *Out2 |= ( (*In) & MASK ) >> RIGHT_SHIFT; - - if constexpr (false && USE_BF16) { - // This snippet adds the exponent bias directly to the exponent bits instead of constructing a bf16 type and - // multiplying with 2^bias in the `MultScale()` function. However, this option is slower, so I don't use it. - // Figure 6 in the FP6 paper provides a helpful visualization of how the FP6 bits are layed out. 
- constexpr int MASK1_EXP = 0x80000000; - constexpr int MASK2_EXP = MASK1_EXP >> (EXPONENT + RIGHT_SHIFT); // NB: arithmetic shift, not logical - constexpr int MASK3_EXP = MASK2_EXP & 0x7fffffff; - constexpr int MASK_EXP = MASK3_EXP | MASK3_EXP >> 16; - // Extract exponent bits - union { - uint32_t u32; - uint8_t u8[4]; - uint16_t u16[2]; - } tmp1, tmp2; - /* - constexpr uint16_t BIAS_OFFSET = (1 << (8-1)) - (1 << (EXPONENT-1)); - tmp1.u32 = (*Out1 & MASK_EXP) >> 7; - tmp2.u32 = (*Out2 & MASK_EXP) >> 7; - tmp1.u16[0] += BIAS_OFFSET; - tmp1.u16[1] += BIAS_OFFSET; - tmp2.u16[0] += BIAS_OFFSET; - tmp2.u16[1] += BIAS_OFFSET; - *Out1 = (*Out1 & ~MASK_EXP) | (tmp1.u32 << 7); - *Out2 = (*Out2 & ~MASK_EXP) | (tmp2.u32 << 7); - */ - constexpr uint8_t BIAS_OFFSET = (1 << (8-1)) - (1 << (EXPONENT-1)); - tmp1.u32 = (*Out1 & MASK_EXP) << 1; - tmp2.u32 = (*Out2 & MASK_EXP) << 1; - // Add exponent bias to exponent bits (NB: little endian) - tmp1.u8[3] += BIAS_OFFSET; - tmp1.u8[1] += BIAS_OFFSET; - tmp2.u8[3] += BIAS_OFFSET; - tmp2.u8[1] += BIAS_OFFSET; - // Insert new exponent bits - *Out1 = (*Out1 & ~MASK_EXP) | (tmp1.u32 >> 1); - *Out2 = (*Out2 & ~MASK_EXP) | (tmp2.u32 >> 1); - } } template @@ -137,9 +99,6 @@ __device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bflo output_bf16_ptr[0] = __hmul( __hmul(*BF16_1,__float2bfloat16(tmp.f)), Scale); output_bf16_ptr[1] = __hmul( __hmul(*BF16_2,__float2bfloat16(tmp.f)), Scale); } - // Use this if exponent bias is already added in `FPx_FP16_Cast_4Way` - // output_bf16_ptr[0] = __hmul( *BF16_1, Scale); - // output_bf16_ptr[1] = __hmul( *BF16_2, Scale); return output; } From ceaed349a622218db10afdc6e89903b9e82e990e Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Wed, 23 Oct 2024 14:15:54 +0200 Subject: [PATCH 10/30] Reformat --- .../cuda/fp6_llm/utils_parallel_dequant.cuh | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index 87a15fa726..0f607918da 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -69,15 +69,16 @@ __device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bflo __nv_bfloat16* BF16_2 = BF16_1 + 1; uint32_t output; __nv_bfloat16* output_bf16_ptr = reinterpret_cast<__nv_bfloat16*>(&output); - if constexpr (false) { - // Exponent bias is 124, which would lead to multiplication with 2^124, - // which would lead to overflow when stored in a 32 or 64-bit type. - // Instead, we decompose the exponent into smaller values and multiply - // several times. + // Multiply with exponent bias here. The exponent bias is 124, which would + // lead to multiplication with 2^124, which would lead to overflow when + // stored in a 32 or 64-bit type. There are two options: (1) decompose the + // exponent bias into smaller values, or (2) type punning (current choice). + if constexpr (false) { // option 1 (decomposition) + // Decompose the exponent bias into smaller values and multiply several times. __nv_bfloat16 tmp1 = *BF16_1; __nv_bfloat16 tmp2 = *BF16_2; - // FIXME: only works for exponent=3 right now. // Note that for exponent=3, BIAS_OFFSET = 2^7 - 2^2 = 124 = 4*31 + // NOTE: only works for exponent=3 right now. 
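        // Worked arithmetic for the decomposition below (E3M2 only):
        //   BIAS_OFFSET = 2^(8-1) - 2^(3-1) = 128 - 4 = 124 = 4 * 31, so 2^124 = (2^31)^4.
        // The single factor 2^31 fits comfortably in a uint32_t and in bf16 range, and because the
        // unbiased values being rescaled start far below 1, none of the four successive __hmul
        // steps overflows on the way up to the full 2^124 bias.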
const __nv_bfloat16 BIAS = __float2bfloat16(1.0f * (uint32_t(1) << BIAS_OFFSET / 4)); #pragma unroll for (int i = 0; i < 4; i++) { @@ -86,11 +87,8 @@ __device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bflo } output_bf16_ptr[0] = __hmul( tmp1, Scale); output_bf16_ptr[1] = __hmul( tmp2, Scale); - } else { - // Exponent bias is 124, which would lead to multiplication with 2^124, - // which would lead to overflow when stored in a 32 or 64-bit type. - // Instead, we use type punning to directly construct a float with a - // large exponent. + } else { // option 2 (type punning) + // Use type punning to directly construct a float with a large exponent. union { uint32_t u32; float f; From b532c51f5363fd25624905b6bf6c11045c08ead1 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Wed, 23 Oct 2024 15:01:51 +0200 Subject: [PATCH 11/30] Cleanup --- torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index 0f607918da..6db5882a53 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -147,18 +147,14 @@ __device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg) if(i%2==1) Frag_PTR_4bit++; else (*Frag_PTR_4bit) = (*Frag_PTR_4bit) << 4; } - // Packed_FP6 now contains 4x 1234 5600 - // uint32_t out1, out2; FPx_FP16_Cast_4Way(&Packed_FP6, &out1, &out2); - // out1 now contains 2 FP16 values, as shown by R1 in figure 6 - // out2 now contains 2 FP16 values, as shown by R2 in figure 6 // - *OutputRegs = MultScale(out1, Scale_RPTR[0]); // Muliply FP16 scales + *OutputRegs = MultScale(out1, Scale_RPTR[0]); // Muliply FP16/BF16 scales OutputRegs += 1; - *OutputRegs = MultScale(out2, Scale_RPTR[1]); // Muliply FP16 scales + *OutputRegs = MultScale(out2, Scale_RPTR[1]); // Muliply FP16/BF16 scales OutputRegs += 1; - // Updating offset for FP16 scales for every two iterations + // Updating offset for FP16/BF16 scales for every two iterations if(i%2==1) Scale_RPTR += 2; } From e89274b04f94a974ba9b9531eaec40199ee8d93e Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Thu, 24 Oct 2024 08:48:11 +0200 Subject: [PATCH 12/30] Fix alignment --- torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh index b8779b987c..d886533811 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh @@ -48,10 +48,10 @@ template __global__ void SplitK_Reduction(T* C, float* Reduction_Workspace, size_t M_Global, size_t N_Global, int Split_K) { static_assert(std::is_same::value || std::is_same::value, "Type must be float or __nv_bfloat16"); - T* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; - float* WARP_GPTR_R = Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; - T* THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT; - float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * HALF_PER_128BIT; + T* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; + float* WARP_GPTR_R = Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; + T* 
THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT; + float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * HALF_PER_128BIT; // Initializing Thread-Local Results float Results[HALF_PER_128BIT]; #pragma unroll @@ -69,5 +69,5 @@ __global__ void SplitK_Reduction(T* C, float* Reduction_Workspace, size_t M_Glob } else { // __nv_bfloat16> #pragma unroll for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2bfloat16_rn(Results[i]); - } + } } From ac0fbe09eb9d9c1487c72240f47c015b13452778 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Thu, 24 Oct 2024 11:40:13 +0200 Subject: [PATCH 13/30] Remove templated input type whenever possible --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 2 +- torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh | 48 ++++++++++----------- torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh | 4 +- torchao/csrc/cuda/fp6_llm/ptx_mma.cuh | 6 +-- torchao/csrc/cuda/fp6_llm/utils_core.cuh | 12 +++--- torchao/csrc/cuda/fp6_llm/utils_gmem.cuh | 16 +++---- 6 files changed, 42 insertions(+), 46 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index ed856b21b0..1b801c34a3 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -71,7 +71,7 @@ static void Kernel_Ex(cudaStream_t stream, printf("\n"); #endif QUANT_GEMM_Kernel<<>> - (Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + (Weight, reinterpret_cast(Scales), reinterpret_cast(B), C, M_Global, N_Global, K_Global, Split_K); } template diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh index 2c4fb000a9..ef5e2a4c2d 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -47,8 +47,8 @@ * C: col major, FP16 */ template -__global__ void QUANT_GEMM_Kernel(const uint4* Weight, const InputDataType* Scales, - const InputDataType *B, +__global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, + const half *B, OutputDataType* C, const size_t M_Global, const size_t N_Global, const size_t K_Global, int Split_K) @@ -67,12 +67,8 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const InputDataType* Scal const uint4* Weight_2bit = Weight_1bit + (USE_SEG_1BIT ? M_Global*K_Global*BIT_WIDTH_1/128 : 0); const uint4* Weight_4bit = Weight_2bit + (USE_SEG_2BIT ? M_Global*K_Global*BIT_WIDTH_2/128 : 0); // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned - extern __shared__ __align__(128) half smem_[]; - // Defining smem like this is necessary for templated kernels (defining it as - // a fixed type and then casting it to the template type). See - // https://leimao.github.io/blog/CUDA-Shared-Memory-Templated-Kernel/ for details. 
- InputDataType* smem = reinterpret_cast(smem_); - InputDataType (*smem_array)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = reinterpret_cast ( smem + SMEM_SIZE_PER_TB_A_TILE/2 ); // Dynamic shared memory for FP16 B tiles + extern __shared__ __align__(128) half smem[]; + half (*smem_array)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = reinterpret_cast ( smem + SMEM_SIZE_PER_TB_A_TILE/2 ); // Dynamic shared memory for FP16 B tiles __shared__ InputDataType QuantScales[64*TilingConfig::BLOCK_WARPS]; // static shared memory for quantization scales, 64 row per warp * 4 warps = 512 Bytes // Thread Block Mapping, considering SplitK const size_t BatchID = blockIdx.y / (M_Global/TilingConfig::TILE_M); @@ -121,21 +117,21 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const InputDataType* Scal AFrag_4BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_4BIT/4; // Pre-fetch of A tile for(int i=0; i(AFrag_1BIT_SPTR+i*SMEM_SIZE_PER_WARP_1BIT/4*4, WARP_StartGPTR_A_1BIT); - if(USE_SEG_2BIT) CopyFromGlobalToShared_A(AFrag_2BIT_SPTR+i*SMEM_SIZE_PER_WARP_2BIT/4*4, WARP_StartGPTR_A_2BIT); - if(USE_SEG_4BIT) CopyFromGlobalToShared_A(AFrag_4BIT_SPTR+i*SMEM_SIZE_PER_WARP_4BIT/4*4, WARP_StartGPTR_A_4BIT); + if(USE_SEG_1BIT) CopyFromGlobalToShared_A(AFrag_1BIT_SPTR+i*SMEM_SIZE_PER_WARP_1BIT/4*4, WARP_StartGPTR_A_1BIT); + if(USE_SEG_2BIT) CopyFromGlobalToShared_A(AFrag_2BIT_SPTR+i*SMEM_SIZE_PER_WARP_2BIT/4*4, WARP_StartGPTR_A_2BIT); + if(USE_SEG_4BIT) CopyFromGlobalToShared_A(AFrag_4BIT_SPTR+i*SMEM_SIZE_PER_WARP_4BIT/4*4, WARP_StartGPTR_A_4BIT); WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT/16; WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT/16; WARP_StartGPTR_A_4BIT += SMEM_SIZE_PER_WARP_4BIT/16; } // Global Memory Address for Matrix A (QuantScale) ///////////////////////////////////////////////////////////////////// - const InputDataType* TB_StartGPTR_A_Scale = Scales + (y*TilingConfig::BLOCK_ROW_WARPS) * 64; + const InputDataType* TB_StartGPTR_A_Scale = reinterpret_cast(Scales) + (y*TilingConfig::BLOCK_ROW_WARPS) * 64; const InputDataType* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64; CopyFromGlobalToShared_Scales(QuantScales+WARP_i*64, WARP_StartGPTR_A_Scales); // Copying B tile from Global to Shared, considering SplitK ///////////////////////////////////////////////////////////// - const InputDataType *BTile_GPTR = B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K; + const InputDataType *BTile_GPTR = reinterpret_cast(B) + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K; for(int i=0; i (smem_array+i*TilingConfig::TILE_N, BTile_GPTR, K_Global, NumColumnToCopy); + CopyFromGlobalToShared (smem_array+i*TilingConfig::TILE_N, reinterpret_cast(BTile_GPTR), K_Global, NumColumnToCopy); BTile_GPTR += TilingConfig::TILE_K; } // Register Allocation for A,B, and C, Initilazed to Zeros ///////////////////////////////////////////////////////////////////// @@ -158,7 +154,7 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const InputDataType* Scal ExtractFromSharedToReg_Scales(Scales_RPTR, QuantScales + WARP_i*64); // Initializing the Software Pipeline: writing registers. //////////////////////////////////////////////////////////////////////////////////////////////// constexpr bool USE_BF16 = std::is_same::value; - initialize_mma_slice(a, b, AFrag_1BIT_SPTR, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); + initialize_mma_slice(a, b, AFrag_1BIT_SPTR, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); // The outer loop. 
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// #pragma unroll(1) for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++) @@ -175,29 +171,29 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const InputDataType* Scal uint32_t* __restrict__ write_SPTR_Frag_4bit = AFrag_4BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 // Trible-Buffer for B Tile // MODIFICATION NOTE: to support MSVC, half __restrict__ (*read_SPTR ) is changed to below. similarly for read2_SPTR and write_SPTR. - InputDataType (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; - InputDataType (* __restrict__ read2_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; - InputDataType (* __restrict__ write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; + half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; + half (* __restrict__ read2_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; + half (* __restrict__ write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; // bool GlobalCopy = (tile_id_k+PIPELINE_LEVEL_GMEM-1) < NumIter; // Copying A tile from Global to Register, Bypassing L1, using double-buffer - if(USE_SEG_1BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_1bit, WARP_StartGPTR_A_1BIT, GlobalCopy); - if(USE_SEG_2BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_2bit, WARP_StartGPTR_A_2BIT, GlobalCopy); - if(USE_SEG_4BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_4bit, WARP_StartGPTR_A_4BIT, GlobalCopy); + if(USE_SEG_1BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_1bit, WARP_StartGPTR_A_1BIT, GlobalCopy); + if(USE_SEG_2BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_2bit, WARP_StartGPTR_A_2BIT, GlobalCopy); + if(USE_SEG_4BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_4bit, WARP_StartGPTR_A_4BIT, GlobalCopy); // copying B tile from GlobalMemory to SharedMemory - CopyFromGlobalToShared (write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy); + CopyFromGlobalToShared (write_SPTR, reinterpret_cast(BTile_GPTR), K_Global, NumColumnToCopy, GlobalCopy); #if __CUDA_ARCH__ >= 800 cp_async_group_commit(); #endif - core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag_2bit, read_SPTR_Frag_4bit are different for each WARP; read_SPTR is shared among WARPs - core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 2); - core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 3); + core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag_2bit, read_SPTR_Frag_4bit are different for each WARP; read_SPTR is shared among WARPs + core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 2); + core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, 
read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 3); // Barriers and Synchronizations #if __CUDA_ARCH__ >= 800 cp_async_wait_group(); #endif __syncthreads(); - core_mma_slice(c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit, read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0); + core_mma_slice(c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit, read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0); // Updating global PTRs WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT/16; // 2KB/16=128 (1)/16: int4*+1 = char*+16 WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT/16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 diff --git a/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh b/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh index 7ce05cefc8..736cbdd5c0 100644 --- a/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh +++ b/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh @@ -38,8 +38,8 @@ // #endif #include -template -__device__ __forceinline__ void cp_async(T* smem_ptr, const T* global_ptr, bool pred_guard = true) +template +__device__ __forceinline__ void cp_async(half* smem_ptr, const half* global_ptr, bool pred_guard = true) { static_assert(SizeInBytes == 16, "Size is not supported"); unsigned smem_int_ptr = __cvta_generic_to_shared(smem_ptr); diff --git a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh index cbaf8b14f6..2c6c4e43b9 100644 --- a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh +++ b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh @@ -46,9 +46,9 @@ // MODIFICATION NOTE: to support MSVC // - uint32_t __restrict__ Reg[][4] is changed to uint32_t (* __restrict__ Reg)[4] // - half __restrict__ (*read_SPTR) is changed to half (* __restrict__ read_SPTR) -template +template __device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[4], - T (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], int slice_id) { #ifdef DEBUG_MODE static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) ); @@ -85,7 +85,7 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[ asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3]) : "r"(smem_local_ptr)); - smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(T); + smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); } } } diff --git a/torchao/csrc/cuda/fp6_llm/utils_core.cuh b/torchao/csrc/cuda/fp6_llm/utils_core.cuh index 070490a77a..4601edb397 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_core.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_core.cuh @@ -35,13 +35,13 @@ __device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], u } // MODIFICATION NOTE: to support MSVC, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below. 
-template +template __device__ __forceinline__ void initialize_mma_slice(uint32_t (*a)[4], uint32_t (*b)[4], uint32_t* __restrict__ A_1BIT_SPTR_read, uint32_t* __restrict__ A_2BIT_SPTR_read, uint32_t* __restrict__ A_4BIT_SPTR_read, - T (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], uint32_t* RPTR_Scales) { // 1+2+4 weight split @@ -58,18 +58,18 @@ __device__ __forceinline__ void initialize_mma_slice(uint32_t ( if(USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2> (a_2bit, A_2BIT_SPTR_read, 0); if(USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4> (a_4bit, A_4BIT_SPTR_read, 0); Dequant_32FP6_4Way(a, a_1bit, a_2bit, a_4bit, RPTR_Scales); // SIMT Dequant: dequantizing FPx to FP16 at register level, dequantizing a slice each time - B_FromSharedToReg(b, B_SPTR_read, 0); // Loading B from shared to registers + B_FromSharedToReg(b, B_SPTR_read, 0); // Loading B from shared to registers } // MODIFICATION NOTE: to support MSVC, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below. -template +template __device__ __forceinline__ void core_mma_slice(float c[][REG_PER_THREAD_C_TENSOR_16_16], uint32_t (*a)[4], uint32_t (*b)[4], uint32_t* __restrict__ A_1bit_SPTR_read, uint32_t* __restrict__ A_2bit_SPTR_read, uint32_t* __restrict__ A_4bit_SPTR_read, - T (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], uint32_t* RPTR_Scales, int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching { @@ -117,7 +117,7 @@ __device__ __forceinline__ void core_mma_slice(float c[][REG if(USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2> (a_2bit, A_2bit_SPTR_read, slice_id); if(USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4> (a_4bit, A_4bit_SPTR_read, slice_id); Dequant_32FP6_4Way(a_write, a_1bit, a_2bit, a_4bit, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time - B_FromSharedToReg (b_write, B_SPTR_read, slice_id); // Loading B from shared to registers + B_FromSharedToReg (b_write, B_SPTR_read, slice_id); // Loading B from shared to registers } template diff --git a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh index ca02c1d7d3..796719f2b9 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh @@ -29,7 +29,7 @@ * Copying A1/A2 from global memory to shared memory. 
* Usually 1024 or 2048 Bytes */ -template +template __device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, const uint4* GPTR, bool pred_guard = true) { @@ -37,8 +37,8 @@ __device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, static_assert(SMEM_SIZE_IN_BYTES_PER_WARP/WARP_SIZE % 16 == 0); #endif int lane_id = threadIdx.x % WARP_SIZE; - T* SPTR_T = reinterpret_cast(SPTR); - const T* GPTR_T = reinterpret_cast(GPTR); + half* SPTR_T = reinterpret_cast(SPTR); + const half* GPTR_T = reinterpret_cast(GPTR); SPTR_T += lane_id*8; GPTR_T += lane_id*8; #pragma unroll @@ -50,7 +50,7 @@ __device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, SPTR_VEC[0] = GPTR_VEC[0]; } #else - cp_async( SPTR_T, GPTR_T, pred_guard); + cp_async<16>( SPTR_T, GPTR_T, pred_guard); #endif SPTR_T += 256; // Forward 512 Bytes GPTR_T += 256; // Forward 512 Bytes @@ -76,9 +76,9 @@ __device__ __forceinline__ void CopyFromGlobalToShared_Scales(T* SPTR_QuantScale * (2) Copying 64 rows * X columns of FP16 values, originally in column major * 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8 Threads */ -template -__device__ __forceinline__ void CopyFromGlobalToShared(T (* __restrict__ SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], - const T* GlobalPTR, +template +__device__ __forceinline__ void CopyFromGlobalToShared(half (* __restrict__ SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + const half* GlobalPTR, const int GlobalStride, const int NumOfLinesLeft, // To support arbitrary N dimensions. bool Pred = true) { @@ -102,7 +102,7 @@ __device__ __forceinline__ void CopyFromGlobalToShared(T (* __restrict__ SharedP SharedPtrVec[0] = GlobalPtrVec[0]; } #else - cp_async( &(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred); + cp_async<16>( &(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred); #endif GlobalPTR += NumOfGroups * GlobalStride; SharedPTR += NumOfGroups; From c1dce42044f6324f9248308aa9e52ae597e3443f Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Thu, 24 Oct 2024 11:59:32 +0200 Subject: [PATCH 14/30] Remove templated input type whenever possible 2 --- torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh | 6 +++--- torchao/csrc/cuda/fp6_llm/utils_gmem.cuh | 18 +++++++++--------- .../cuda/fp6_llm/utils_parallel_dequant.cuh | 3 +-- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh index ef5e2a4c2d..0163bb6081 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -69,7 +69,7 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned extern __shared__ __align__(128) half smem[]; half (*smem_array)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = reinterpret_cast ( smem + SMEM_SIZE_PER_TB_A_TILE/2 ); // Dynamic shared memory for FP16 B tiles - __shared__ InputDataType QuantScales[64*TilingConfig::BLOCK_WARPS]; // static shared memory for quantization scales, 64 row per warp * 4 warps = 512 Bytes + __shared__ half QuantScales[64*TilingConfig::BLOCK_WARPS]; // static shared memory for quantization scales, 64 row per warp * 4 warps = 512 Bytes // Thread Block Mapping, considering SplitK const size_t BatchID = blockIdx.y / (M_Global/TilingConfig::TILE_M); const size_t x = blockIdx.x; // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) @@ -127,7 +127,7 @@ __global__ void 
QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, // Global Memory Address for Matrix A (QuantScale) ///////////////////////////////////////////////////////////////////// const InputDataType* TB_StartGPTR_A_Scale = reinterpret_cast(Scales) + (y*TilingConfig::BLOCK_ROW_WARPS) * 64; const InputDataType* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64; - CopyFromGlobalToShared_Scales(QuantScales+WARP_i*64, WARP_StartGPTR_A_Scales); + CopyFromGlobalToShared_Scales(reinterpret_cast(QuantScales+WARP_i*64), WARP_StartGPTR_A_Scales); // Copying B tile from Global to Shared, considering SplitK ///////////////////////////////////////////////////////////// const InputDataType *BTile_GPTR = reinterpret_cast(B) + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K; for(int i=0; i(Scales_RPTR, QuantScales + WARP_i*64); + ExtractFromSharedToReg_Scales(Scales_RPTR, QuantScales + WARP_i*64); // Initializing the Software Pipeline: writing registers. //////////////////////////////////////////////////////////////////////////////////////////////// constexpr bool USE_BF16 = std::is_same::value; initialize_mma_slice(a, b, AFrag_1BIT_SPTR, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); diff --git a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh index 796719f2b9..2c03046805 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh @@ -37,23 +37,23 @@ __device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, static_assert(SMEM_SIZE_IN_BYTES_PER_WARP/WARP_SIZE % 16 == 0); #endif int lane_id = threadIdx.x % WARP_SIZE; - half* SPTR_T = reinterpret_cast(SPTR); - const half* GPTR_T = reinterpret_cast(GPTR); - SPTR_T += lane_id*8; - GPTR_T += lane_id*8; + half* SPTR_HALF = reinterpret_cast(SPTR); + const half* GPTR_HALF = reinterpret_cast(GPTR); + SPTR_HALF += lane_id*8; + GPTR_HALF += lane_id*8; #pragma unroll for(int i=0; i(SPTR_T); - const float4* GPTR_VEC = reinterpret_cast(GPTR_T); + float4* SPTR_VEC = reinterpret_cast(SPTR_HALF); + const float4* GPTR_VEC = reinterpret_cast(GPTR_HALF); SPTR_VEC[0] = GPTR_VEC[0]; } #else - cp_async<16>( SPTR_T, GPTR_T, pred_guard); + cp_async<16>( SPTR_HALF, GPTR_HALF, pred_guard); #endif - SPTR_T += 256; // Forward 512 Bytes - GPTR_T += 256; // Forward 512 Bytes + SPTR_HALF += 256; // Forward 512 Bytes + GPTR_HALF += 256; // Forward 512 Bytes } } diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index 6db5882a53..0b7529b68b 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -163,8 +163,7 @@ __device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg) /* * */ -template -__device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, T* WARP_SPTR_Scales) { +__device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, half* WARP_SPTR_Scales) { int lane_id = threadIdx.x % WARP_SIZE; uint32_t* SPTR_uint = reinterpret_cast(WARP_SPTR_Scales); uint32_t tmpReg = SPTR_uint[lane_id]; From 4546c8bb345311d06490686dd9c321d729f98c36 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Thu, 24 Oct 2024 13:33:52 +0200 Subject: [PATCH 15/30] Remove templated input type whenever possible 3 --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 45 +++++++++---------- torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh | 12 
++--- .../csrc/cuda/fp6_llm/kernel_reduction.cuh | 1 - torchao/csrc/cuda/fp6_llm/utils_gmem.cuh | 5 +-- .../cuda/fp6_llm/utils_parallel_dequant.cuh | 6 +-- 5 files changed, 33 insertions(+), 36 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 1b801c34a3..286100dc27 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -42,15 +42,15 @@ inline bool isSM75GPU() { } template -static void Kernel_Ex(cudaStream_t stream, - const uint4 *Weight, - const InputDataType *Scales, - const InputDataType *B, - OutputDataType *C, - const size_t M_Global, - const size_t N_Global, - const size_t K_Global, - int Split_K) +static void Kernel_Ex(cudaStream_t stream, + const uint4 *Weight, + const half *Scales, + const half *B, + OutputDataType *C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + int Split_K) { #ifdef DEBUG_MODE printf("\n"); @@ -71,20 +71,20 @@ static void Kernel_Ex(cudaStream_t stream, printf("\n"); #endif QUANT_GEMM_Kernel<<>> - (Weight, reinterpret_cast(Scales), reinterpret_cast(B), C, M_Global, N_Global, K_Global, Split_K); + (Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); } template -cudaError_t fpx_linear_kernel(cudaStream_t stream, - const uint4 *Weight, - const InputDataType *Scales, - const InputDataType *B, - InputDataType *C, - const size_t M_Global, - const size_t N_Global, - const size_t K_Global, - float *Reduction_Workspace, // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) - int Split_K) +cudaError_t fpx_linear_kernel(cudaStream_t stream, + const uint4 *Weight, + const half *Scales, + const half *B, + InputDataType *C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + float *Reduction_Workspace, // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) + int Split_K) { static_assert(std::is_same::value || std::is_same::value, "Type must be float or __nv_bfloat16"); assert(M_Global % 256 == 0); @@ -205,7 +205,6 @@ torch::Tensor fp_eXmY_linear_forward_cuda( int num_out_channels = _weights.size(0); TORCH_CHECK(num_in_channels % 64 == 0, "Expected in_features to be a multiple of 64, but received ", num_in_channels); TORCH_CHECK((num_in_channels / 8 * NBITS) == _weights.size(1)); // Making sure the K dimension is matched. - TORCH_CHECK(_in_feats.dtype() == _scales.dtype()); // int M = num_out_channels; int K = num_in_channels; @@ -224,8 +223,8 @@ torch::Tensor fp_eXmY_linear_forward_cuda( DISPATCH_HALF_AND_BF16(_in_feats.scalar_type(), "fpx_linear_kernel", [&] { auto weight = reinterpret_cast(_weights.data_ptr()); // weights is [OC, IC] but in FP6. 
- auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto scales = reinterpret_cast(_scales.data_ptr()); + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto scales = reinterpret_cast(_scales.data_ptr()); auto out_feats = reinterpret_cast(_out_feats.data_ptr()); // officially supported in Quant-LLM diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh index 0163bb6081..9a31353fee 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -125,13 +125,13 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, WARP_StartGPTR_A_4BIT += SMEM_SIZE_PER_WARP_4BIT/16; } // Global Memory Address for Matrix A (QuantScale) ///////////////////////////////////////////////////////////////////// - const InputDataType* TB_StartGPTR_A_Scale = reinterpret_cast(Scales) + (y*TilingConfig::BLOCK_ROW_WARPS) * 64; - const InputDataType* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64; - CopyFromGlobalToShared_Scales(reinterpret_cast(QuantScales+WARP_i*64), WARP_StartGPTR_A_Scales); + const half* TB_StartGPTR_A_Scale = Scales + (y*TilingConfig::BLOCK_ROW_WARPS) * 64; + const half* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64; + CopyFromGlobalToShared_Scales(QuantScales+WARP_i*64, WARP_StartGPTR_A_Scales); // Copying B tile from Global to Shared, considering SplitK ///////////////////////////////////////////////////////////// - const InputDataType *BTile_GPTR = reinterpret_cast(B) + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K; + const half *BTile_GPTR = B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K; for(int i=0; i (smem_array+i*TilingConfig::TILE_N, reinterpret_cast(BTile_GPTR), K_Global, NumColumnToCopy); + CopyFromGlobalToShared (smem_array+i*TilingConfig::TILE_N, BTile_GPTR, K_Global, NumColumnToCopy); BTile_GPTR += TilingConfig::TILE_K; } // Register Allocation for A,B, and C, Initilazed to Zeros ///////////////////////////////////////////////////////////////////// @@ -181,7 +181,7 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, if(USE_SEG_2BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_2bit, WARP_StartGPTR_A_2BIT, GlobalCopy); if(USE_SEG_4BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_4bit, WARP_StartGPTR_A_4BIT, GlobalCopy); // copying B tile from GlobalMemory to SharedMemory - CopyFromGlobalToShared (write_SPTR, reinterpret_cast(BTile_GPTR), K_Global, NumColumnToCopy, GlobalCopy); + CopyFromGlobalToShared (write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy); #if __CUDA_ARCH__ >= 800 cp_async_group_commit(); #endif diff --git a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh index d886533811..0c09c37811 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh @@ -47,7 +47,6 @@ template __global__ void SplitK_Reduction(T* C, float* Reduction_Workspace, size_t M_Global, size_t N_Global, int Split_K) { - static_assert(std::is_same::value || std::is_same::value, "Type must be float or __nv_bfloat16"); T* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; float* WARP_GPTR_R = Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; T* THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT; diff --git a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh index 2c03046805..f2af30733f 100644 --- 
a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh @@ -61,9 +61,8 @@ __device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, /* * Copying 64 Quant Scales (FP16) from global memory to shared memory. */ -template -__device__ __forceinline__ void CopyFromGlobalToShared_Scales(T* SPTR_QuantScales, - const T* GPTR_A_Scales) { +__device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantScales, + const half* GPTR_A_Scales) { int lane_id = threadIdx.x % WARP_SIZE; int Offset_Shared = lane_id*2; int Offset_Global = lane_id/4 + (lane_id%4)*16; diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index 0b7529b68b..a0f4d4989a 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -36,7 +36,7 @@ __device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t *In, uint32_t *Out1, // constexpr int RIGHT_SHIFT = USE_BF16 ? 8 - EXPONENT : 5 - EXPONENT; constexpr int MASK1 = 0x80000000; - constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA; // NB: arithmetic shift, not logical + constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA; constexpr int MASK3 = MASK2 & 0x7fffffff; constexpr int MASK = MASK3 | MASK3 >> 16; // @@ -77,7 +77,7 @@ __device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bflo // Decompose the exponent bias into smaller values and multiply several times. __nv_bfloat16 tmp1 = *BF16_1; __nv_bfloat16 tmp2 = *BF16_2; - // Note that for exponent=3, BIAS_OFFSET = 2^7 - 2^2 = 124 = 4*31 + // Note that for exponent=3, BIAS_OFFSET = 2^7 - 2^2 = 124 = 4*31 // NOTE: only works for exponent=3 right now. const __nv_bfloat16 BIAS = __float2bfloat16(1.0f * (uint32_t(1) << BIAS_OFFSET / 4)); #pragma unroll @@ -150,7 +150,7 @@ __device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg) uint32_t out1, out2; FPx_FP16_Cast_4Way(&Packed_FP6, &out1, &out2); // - *OutputRegs = MultScale(out1, Scale_RPTR[0]); // Muliply FP16/BF16 scales + *OutputRegs = MultScale(out1, Scale_RPTR[0] ); // Muliply FP16/BF16 scales OutputRegs += 1; *OutputRegs = MultScale(out2, Scale_RPTR[1]); // Muliply FP16/BF16 scales OutputRegs += 1; From bba42cf949c087fb0954de49d91953097e2a8c96 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Thu, 24 Oct 2024 13:54:39 +0200 Subject: [PATCH 16/30] Less hacky way to construct a float with a large exponent --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 1 - .../cuda/fp6_llm/utils_parallel_dequant.cuh | 33 +++---------------- 2 files changed, 5 insertions(+), 29 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 286100dc27..b67e9f6b76 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -209,7 +209,6 @@ torch::Tensor fp_eXmY_linear_forward_cuda( int M = num_out_channels; int K = num_in_channels; int N = num_in_feats; - auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); at::Tensor _out_feats = torch::empty({num_in_feats, num_out_channels}, options); diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index a0f4d4989a..b164ab60a6 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -69,34 +69,11 @@ 
__device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bflo __nv_bfloat16* BF16_2 = BF16_1 + 1; uint32_t output; __nv_bfloat16* output_bf16_ptr = reinterpret_cast<__nv_bfloat16*>(&output); - // Multiply with exponent bias here. The exponent bias is 124, which would - // lead to multiplication with 2^124, which would lead to overflow when - // stored in a 32 or 64-bit type. There are two options: (1) decompose the - // exponent bias into smaller values, or (2) type punning (current choice). - if constexpr (false) { // option 1 (decomposition) - // Decompose the exponent bias into smaller values and multiply several times. - __nv_bfloat16 tmp1 = *BF16_1; - __nv_bfloat16 tmp2 = *BF16_2; - // Note that for exponent=3, BIAS_OFFSET = 2^7 - 2^2 = 124 = 4*31 - // NOTE: only works for exponent=3 right now. - const __nv_bfloat16 BIAS = __float2bfloat16(1.0f * (uint32_t(1) << BIAS_OFFSET / 4)); - #pragma unroll - for (int i = 0; i < 4; i++) { - tmp1 = __hmul(tmp1, BIAS); - tmp2 = __hmul(tmp2, BIAS); - } - output_bf16_ptr[0] = __hmul( tmp1, Scale); - output_bf16_ptr[1] = __hmul( tmp2, Scale); - } else { // option 2 (type punning) - // Use type punning to directly construct a float with a large exponent. - union { - uint32_t u32; - float f; - } tmp; - tmp.u32 = (BIAS_OFFSET + 127) << 23; // 127=exponent bias, 23=mantissa - output_bf16_ptr[0] = __hmul( __hmul(*BF16_1,__float2bfloat16(tmp.f)), Scale); - output_bf16_ptr[1] = __hmul( __hmul(*BF16_2,__float2bfloat16(tmp.f)), Scale); - } + // Directly construct a float from the exponent because + // `2^{BIAS_OFFSET} = 2^{124}` (for FP6) is too large to store in an integer. + const float bias = ldexpf(1.0f, BIAS_OFFSET); + output_bf16_ptr[0] = __hmul( __hmul(*BF16_1,__float2bfloat16(bias)), Scale); + output_bf16_ptr[1] = __hmul( __hmul(*BF16_2,__float2bfloat16(bias)), Scale); return output; } From e66395ef0953b5c36f59dd0071aabb4db74a7616 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Thu, 24 Oct 2024 15:23:43 +0200 Subject: [PATCH 17/30] rtol=1e-2 instead of 1e-3 for bfloat16 test --- test/test_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index c174cec794..7802fdeaeb 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -74,7 +74,8 @@ def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK, dt error = (results_floatx - results_fp16).abs().mean() gt = results_fp16.abs().mean() relative_error = error / gt - assert relative_error < 1e-3 + rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3 + assert relative_error < rtol instantiate_parametrized_tests(TestOps) From 7e9350ec357ce5fa28aae707c8307e4540cfb616 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Thu, 24 Oct 2024 16:12:21 +0200 Subject: [PATCH 18/30] Guards for SM75 --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 21 ++++++++++++++++--- torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh | 6 ++++++ .../csrc/cuda/fp6_llm/kernel_reduction.cuh | 9 ++++++-- torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh | 3 --- torchao/csrc/cuda/fp6_llm/ptx_mma.cuh | 3 --- .../cuda/fp6_llm/utils_parallel_dequant.cuh | 11 +++++++--- 6 files changed, 39 insertions(+), 14 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index b67e9f6b76..3b86ed2d22 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -49,8 
+49,8 @@ static void Kernel_Ex(cudaStream_t stream, OutputDataType *C, const size_t M_Global, const size_t N_Global, - const size_t K_Global, - int Split_K) + const size_t K_Global, + int Split_K) { #ifdef DEBUG_MODE printf("\n"); @@ -86,7 +86,9 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, float *Reduction_Workspace, // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) int Split_K) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 static_assert(std::is_same::value || std::is_same::value, "Type must be float or __nv_bfloat16"); + #endif assert(M_Global % 256 == 0); assert(K_Global % 64 == 0); assert(N_Global>0); @@ -155,7 +157,19 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, #include // https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h -// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#if __CUDA_ARCH__ == 750 +#define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Half: { \ + using torch_t = at::Half; \ + using nv_t = half; \ + __VA_ARGS__(); \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } +#else #define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) \ switch (TYPE) { \ case at::ScalarType::Half: { \ @@ -173,6 +187,7 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, default: \ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } +#endif namespace torchao { // MODIFICATION NOTE: dtype of _weights is changed to uint8 diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh index 9a31353fee..c555be6910 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -153,7 +153,11 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, uint32_t Scales_RPTR[4]; // 4 Registers per thread for Quantization Scales ExtractFromSharedToReg_Scales(Scales_RPTR, QuantScales + WARP_i*64); // Initializing the Software Pipeline: writing registers. //////////////////////////////////////////////////////////////////////////////////////////////// + #if __CUDA_ARCH__ >= 800 constexpr bool USE_BF16 = std::is_same::value; + #else + constexpr bool USE_BF16 = false; + #endif initialize_mma_slice(a, b, AFrag_1BIT_SPTR, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); // The outer loop. 
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// #pragma unroll(1) @@ -215,8 +219,10 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, { if constexpr (std::is_same::value) BlockGlobalPTR[j+i*M_Global] = __float2half_rn(smem_CFrag[i][j]); + #if __CUDA_ARCH__ >= 800 else if constexpr (std::is_same::value) BlockGlobalPTR[j+i*M_Global] = __float2bfloat16_rn(smem_CFrag[i][j]); + #endif else BlockGlobalPTR[j+i*M_Global] = smem_CFrag[i][j]; } diff --git a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh index 0c09c37811..d7a4c4224b 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh @@ -36,9 +36,9 @@ #include #include -// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 #include -// #endif +#endif #include #define REDUCTION_ELEMENT_PER_THREADBLOCK 256 @@ -62,6 +62,10 @@ __global__ void SplitK_Reduction(T* C, float* Reduction_Workspace, size_t M_Glob THREAD_GPTR_R += M_Global * N_Global; } // Writing to global memory + #if __CUDA_ARCH__ < 800 + #pragma unroll + for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]); + #else if constexpr (std::is_same::value) { #pragma unroll for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]); @@ -69,4 +73,5 @@ __global__ void SplitK_Reduction(T* C, float* Reduction_Workspace, size_t M_Glob #pragma unroll for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2bfloat16_rn(Results[i]); } + #endif } diff --git a/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh b/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh index 736cbdd5c0..c1d064f32a 100644 --- a/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh +++ b/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh @@ -33,9 +33,6 @@ #include #include -// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#include -// #endif #include template diff --git a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh index 2c6c4e43b9..57dd8cb53f 100644 --- a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh +++ b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh @@ -35,9 +35,6 @@ #include #include -// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#include -// #endif #include #include diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index b164ab60a6..e780f2b434 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -20,10 +20,9 @@ #include #include -// TODO: can cuda_bf16 be imported for SM75? How to guard against this? 
The guard below does not work outside of device code -// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 #include -// #endif +#endif #include /* @@ -62,6 +61,7 @@ __device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scal return output; } +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 template __device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bfloat16 Scale) { constexpr int BIAS_OFFSET = (int(1) << (8-1)) - (int(1) << (EXPONENT-1)); @@ -76,6 +76,7 @@ __device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bflo output_bf16_ptr[1] = __hmul( __hmul(*BF16_2,__float2bfloat16(bias)), Scale); return output; } +#endif // MODIFICATION NOTE: to support MSVC // - u_int32_t __restrict__ Reg[][4] is changed to below. @@ -96,7 +97,11 @@ __device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg) uint32_t *Frag_PTR_1bit = read_RPTR_1bit; uint32_t *Frag_PTR_2bit = read_RPTR_2bit; uint32_t *Frag_PTR_4bit = read_RPTR_4bit; + #if __CUDA_ARCH__ >= 800 using scalar_t = typename std::conditional::type; + #else + using scalar_t = half; + #endif scalar_t *Scale_RPTR = reinterpret_cast(Scales); // Dequantizing 32 FP6, each Loop dequantizing 4 FP6 #pragma unroll(8) From 401559f3a6a006639293d08a896bb03cdec05531 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Tue, 29 Oct 2024 11:04:07 +0100 Subject: [PATCH 19/30] Remove redundant `__CUDA_ARCH` guards in host code Any check for `__CUDA_ARCH__` in `fp6_linear.cu` will always fail because `__CUDA_ARCH__` is undefined since all of the functions in `fp6_linear.cu` are host functions --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 3b86ed2d22..44510f0fe1 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -18,8 +18,6 @@ // - Modified the TilingConfig parameters for SM75 to deal with smaller shared memory // -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750 // at least Turing - #include "kernel_matmul.cuh" #include "kernel_reduction.cuh" @@ -86,9 +84,7 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, float *Reduction_Workspace, // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) int Split_K) { - #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 - static_assert(std::is_same::value || std::is_same::value, "Type must be float or __nv_bfloat16"); - #endif + static_assert(std::is_same::value || std::is_same::value, "Type must be 'half' or '__nv_bfloat16'"); assert(M_Global % 256 == 0); assert(K_Global % 64 == 0); assert(N_Global>0); @@ -157,19 +153,6 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, #include // https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h -#if __CUDA_ARCH__ == 750 -#define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) \ - switch (TYPE) { \ - case at::ScalarType::Half: { \ - using torch_t = at::Half; \ - using nv_t = half; \ - __VA_ARGS__(); \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } -#else #define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) 
\ switch (TYPE) { \ case at::ScalarType::Half: { \ @@ -187,7 +170,6 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, default: \ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } -#endif namespace torchao { // MODIFICATION NOTE: dtype of _weights is changed to uint8 @@ -271,5 +253,3 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) { } } // namespace torchao - -#endif From 5d52e5bf33462c21bfaec4d7ec267fb2b89c3d53 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Tue, 29 Oct 2024 11:04:31 +0100 Subject: [PATCH 20/30] Fix consistency in checking for `CUDA_ARCH` versions --- torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh index d7a4c4224b..009084bf10 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh @@ -62,7 +62,7 @@ __global__ void SplitK_Reduction(T* C, float* Reduction_Workspace, size_t M_Glob THREAD_GPTR_R += M_Global * N_Global; } // Writing to global memory - #if __CUDA_ARCH__ < 800 + #if __CUDA_ARCH__ == 750 #pragma unroll for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]); #else From 398da5b6ee60ea45b791eb1a4d0a3a919c5704f7 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Tue, 29 Oct 2024 11:07:24 +0100 Subject: [PATCH 21/30] Update docs --- torchao/dtypes/floatx/README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchao/dtypes/floatx/README.md b/torchao/dtypes/floatx/README.md index af770cf65c..07066f072d 100644 --- a/torchao/dtypes/floatx/README.md +++ b/torchao/dtypes/floatx/README.md @@ -2,6 +2,8 @@ This is a FP16 x Floatx mixed matmul kernel optimized for io bound workloads per [FP6-LLM](https://arxiv.org/abs/2401.14112). The actual CUDA kernel is located under [csrc/cuda/fp6_llm/](../../csrc/cuda/fp6_llm/). This module provides helper functions to quantize FP32/FP16/BF16 weights to Floatx and integration with torchao API. +This kernel was originally designed for FP16, but was extended to work for BF16 by @tobiasvanderwerff. + ## Usage ```python @@ -11,7 +13,7 @@ from torchao.quantization import ( ) model = ... -model.half() # not necessary, but recommeneded to maintain accuracy +model.half() # not necessary, but recommended to maintain accuracy (bfloat16 cast is also possible) # for generic Floatx EyMz where x = 1 + y + z # fp6 with ebits = 3 and mbits = 2 @@ -40,9 +42,9 @@ outputs = quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scales) # shape ``` **NOTE**: -- Since this kernel's computation dtype is FP16, it is recommended to convert the model to FP16 (instead of BF16) before applying quantization and use FP16 for activations. +- The kernel works for both FP16 and BF16 input activations - Only FP6 E3M2 and FP5 E2M2 are tested and enabled in the official repo. We additionally enable support for FP6 E2M3 and FP5 E3M1. -- On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion. See https://github.com/pytorch/ao/pull/223 for some microbenchmark results. 
+- On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion. See https://github.com/pytorch/ao/pull/223 and https://github.com/pytorch/ao/pull/1147 for some microbenchmark results. - FP6 is supported for >=SM80 (Ampere generation) as well as SM75 (Turing generation) GPUs. However, SM75 support requires manual compilation of the C++/CUDA extensions (see the installation instructions in the [README](https://github.com/pytorch/ao/blob/main/README.md#installation) for details). ## End-to-End benchmarks From d38490fab86039669705153baedbb07784ae73d2 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Wed, 30 Oct 2024 09:54:19 +0100 Subject: [PATCH 22/30] Make float bias a constexpr --- torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index e780f2b434..04f4e2e621 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -61,19 +61,21 @@ __device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scal return output; } +constexpr float power_of_two(int n) { + return (n == 0) ? 1.0f : 2.0f * power_of_two(n - 1); +} + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 template __device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bfloat16 Scale) { constexpr int BIAS_OFFSET = (int(1) << (8-1)) - (int(1) << (EXPONENT-1)); + constexpr float BIAS = power_of_two(BIAS_OFFSET); __nv_bfloat16* BF16_1 = reinterpret_cast<__nv_bfloat16*>(&PackedBF16Pair); __nv_bfloat16* BF16_2 = BF16_1 + 1; uint32_t output; __nv_bfloat16* output_bf16_ptr = reinterpret_cast<__nv_bfloat16*>(&output); - // Directly construct a float from the exponent because - // `2^{BIAS_OFFSET} = 2^{124}` (for FP6) is too large to store in an integer. - const float bias = ldexpf(1.0f, BIAS_OFFSET); - output_bf16_ptr[0] = __hmul( __hmul(*BF16_1,__float2bfloat16(bias)), Scale); - output_bf16_ptr[1] = __hmul( __hmul(*BF16_2,__float2bfloat16(bias)), Scale); + output_bf16_ptr[0] = __hmul( __hmul(*BF16_1,__float2bfloat16(BIAS)), Scale); + output_bf16_ptr[1] = __hmul( __hmul(*BF16_2,__float2bfloat16(BIAS)), Scale); return output; } #endif From 11ac84b3fcb11abefb78adc7be7555312b978d0e Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Wed, 30 Oct 2024 10:16:40 +0100 Subject: [PATCH 23/30] Update docs more --- torchao/csrc/cuda/fp6_llm/README.md | 4 ++-- torchao/dtypes/floatx/README.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/README.md b/torchao/csrc/cuda/fp6_llm/README.md index ff764cc27d..8df1fb1416 100644 --- a/torchao/csrc/cuda/fp6_llm/README.md +++ b/torchao/csrc/cuda/fp6_llm/README.md @@ -1,7 +1,7 @@ # FP6-LLM kernel -This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 and W is in FP6 (E3M2 without infinities and NaN). +This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 or BF16 and W is in FP6 (E3M2 without infinities and NaN). 
On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion. -See https://github.com/pytorch/ao/pull/223 for some benchmark results. +See https://github.com/pytorch/ao/pull/223 and https://github.com/pytorch/ao/pull/1147 for some benchmark results. diff --git a/torchao/dtypes/floatx/README.md b/torchao/dtypes/floatx/README.md index 07066f072d..16aec8362b 100644 --- a/torchao/dtypes/floatx/README.md +++ b/torchao/dtypes/floatx/README.md @@ -13,7 +13,7 @@ from torchao.quantization import ( ) model = ... -model.half() # not necessary, but recommended to maintain accuracy (bfloat16 cast is also possible) +# model can have dtype float16 or bfloat16 # for generic Floatx EyMz where x = 1 + y + z # fp6 with ebits = 3 and mbits = 2 @@ -45,7 +45,7 @@ outputs = quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scales) # shape - The kernel works for both FP16 and BF16 input activations - Only FP6 E3M2 and FP5 E2M2 are tested and enabled in the official repo. We additionally enable support for FP6 E2M3 and FP5 E3M1. - On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion. See https://github.com/pytorch/ao/pull/223 and https://github.com/pytorch/ao/pull/1147 for some microbenchmark results.
-- FP6 is supported for >=SM80 (Ampere generation) as well as SM75 (Turing generation) GPUs. However, SM75 support requires manual compilation of the C++/CUDA extensions (see the installation instructions in the [README](https://github.com/pytorch/ao/blob/main/README.md#installation) for details). +- The kernel is supported for >=SM80 (Ampere generation) as well as SM75 (Turing generation) GPUs. However, SM75 support requires manual compilation of the C++/CUDA extensions (see the installation instructions in the [README](https://github.com/pytorch/ao/blob/main/README.md#installation) for details). ## End-to-End benchmarks From 7bd2833aa92e7790f97427b92d3baeb5982ad8d3 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Wed, 30 Oct 2024 11:49:39 +0100 Subject: [PATCH 24/30] Fix SM75 support --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 44510f0fe1..3aaedc4802 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -18,6 +18,8 @@ // - Modified the TilingConfig parameters for SM75 to deal with smaller shared memory // +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750 // at least Turing + #include "kernel_matmul.cuh" #include "kernel_reduction.cuh" @@ -153,6 +155,19 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, #include // https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h +#if __CUDA_ARCH__ == 750 +#define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Half: { \ + using torch_t = at::Half; \ + using nv_t = half; \ + __VA_ARGS__(); \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } +#else #define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...)
\ switch (TYPE) { \ case at::ScalarType::Half: { \ @@ -170,6 +185,7 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, default: \ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } +#endif namespace torchao { // MODIFICATION NOTE: dtype of _weights is changed to uint8 @@ -253,3 +269,5 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) { } } // namespace torchao + +#endif \ No newline at end of file From 69e901daa3429334c50f1f0fa7b05efa883e5ee2 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Thu, 31 Oct 2024 09:44:55 +0100 Subject: [PATCH 25/30] Compile guard for sm<75 --- torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh index c555be6910..4b5833b91f 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -53,6 +53,9 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, const size_t M_Global, const size_t N_Global, const size_t K_Global, int Split_K) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + #error GPU not supported, at least Turing generation (sm75) is required + #else #ifdef DEBUG_MODE assert(K_Global%TilingConfig::TILE_K==0); assert(M_Global%TilingConfig::TILE_M==0); @@ -226,4 +229,5 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, else BlockGlobalPTR[j+i*M_Global] = smem_CFrag[i][j]; } + #endif } From 8747d6d38a0c1a41377bee3337b4ed7ce489c794 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Thu, 31 Oct 2024 10:05:14 +0100 Subject: [PATCH 26/30] Check for CUDA synchronous errors after kernel launch If this is not done, the kernel may still run but fail silently, leading to unexpected behavior --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 3aaedc4802..c46bf73276 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -18,7 +18,6 @@ // - Modified the TilingConfig parameters for SM75 to deal with smaller shared memory // -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750 // at least Turing #include "kernel_matmul.cuh" #include "kernel_reduction.cuh" @@ -26,6 +25,12 @@ #include #include +#include +#include +#include +#include + + inline bool isSM75GPU() { int device; cudaError_t err = cudaGetDevice(&device); @@ -72,6 +77,7 @@ static void Kernel_Ex(cudaStream_t stream, #endif QUANT_GEMM_Kernel<<>> (Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -149,11 +155,6 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, } -#include -#include -#include -#include - // https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h #if __CUDA_ARCH__ == 750 #define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) 
\ @@ -268,6 +269,4 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) { m.impl("torchao::quant_llm_linear", &fp_eXmY_linear_forward_cuda); } -} // namespace torchao - -#endif \ No newline at end of file +} // namespace torchao \ No newline at end of file From 59f5eb79974551e1653681bda995323ce3bc6e2c Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Thu, 31 Oct 2024 11:43:27 +0100 Subject: [PATCH 27/30] Updated compile guard --- torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh index 4b5833b91f..843fc0721e 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -53,9 +53,11 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, const size_t M_Global, const size_t N_Global, const size_t K_Global, int Split_K) { - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 - #error GPU not supported, at least Turing generation (sm75) is required - #else + #if __CUDA_ARCH__ < 750 + static_assert(false, "FP6: At least Turing generation (sm75) is required"); + // __trap(); // fails at runtime instead of compile time + #endif + #ifdef DEBUG_MODE assert(K_Global%TilingConfig::TILE_K==0); assert(M_Global%TilingConfig::TILE_M==0); @@ -229,5 +231,4 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, else BlockGlobalPTR[j+i*M_Global] = smem_CFrag[i][j]; } - #endif } From c96cf18e7b22ca211b78944a5af424a01f9bb949 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Fri, 1 Nov 2024 09:33:40 +0100 Subject: [PATCH 28/30] Fix problematic usage of `__CUDA_ARCH__` There are currently several ways of using `__CUDA_ARCH__` that lead to undefined behavior. 
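As a rough sketch of the failure mode (hypothetical snippet, not code from this repository): in host code `__CUDA_ARCH__` is simply undefined, so the preprocessor treats it as 0 and an architecture guard silently selects the wrong branch at build time, whereas a runtime query via `cudaDeviceGetAttribute` (the approach taken in this patch) reports the GPU that is actually present.

    #include <cuda_runtime.h>
    #include <cstdio>

    // Hypothetical host-side helper, for illustration only.
    void report_compute_capability() {
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
        // Never emitted for host code: __CUDA_ARCH__ is only defined while
        // nvcc compiles device code, so this branch is dropped here.
        std::printf("compile-time guard: sm80+\n");
    #else
        std::printf("compile-time guard: pre-sm80 (always taken in host code)\n");
    #endif
        // Reliable host-side alternative: query the device at runtime.
        int device = 0, major = 0, minor = 0;
        cudaGetDevice(&device);
        cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
        cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
        std::printf("runtime query: sm%d%d\n", major, minor);
    }
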
See https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-arch for details of how `__CUDA_ARCH__` should not be used --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 47 +++++++------------ torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh | 20 ++++---- .../csrc/cuda/fp6_llm/kernel_reduction.cuh | 7 +-- .../cuda/fp6_llm/utils_parallel_dequant.cuh | 8 +--- 4 files changed, 29 insertions(+), 53 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index c46bf73276..a4c24eccb2 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -31,21 +31,6 @@ #include -inline bool isSM75GPU() { - int device; - cudaError_t err = cudaGetDevice(&device); - if (err != cudaSuccess) return false; - - int major, minor; - err = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device); - if (err != cudaSuccess) return false; - - err = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device); - if (err != cudaSuccess) return false; - - return (major == 7) && (minor == 5); -} - template static void Kernel_Ex(cudaStream_t stream, const uint4 *Weight, @@ -106,7 +91,23 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, if(N_Global>64 && N_Global<=128) N_PowerOf2 = 128; if(N_Global>128) N_PowerOf2 = ((N_Global-1)/128+1) * 128; - if (isSM75GPU() && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) { + // Check GPU Compute Capability + int device, major, minor; + cudaError_t err = cudaGetDevice(&device); + if (err != cudaSuccess) return err; + err = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device); + if (err != cudaSuccess) return err; + err = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device); + if (err != cudaSuccess) return err; + + if ((major < 7) || (major == 7 && minor < 5)) { + printf("FP6LLM_API Error: FP6LLM requires GPU with SM75 or higher!\n"); + return cudaErrorUnknown; + } + + const bool is_sm75_gpu = (major == 7) && (minor == 5); + + if (is_sm75_gpu && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) { // For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory. if (Split_K == 1) { Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); @@ -156,19 +157,6 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, // https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h -#if __CUDA_ARCH__ == 750 -#define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) \ - switch (TYPE) { \ - case at::ScalarType::Half: { \ - using torch_t = at::Half; \ - using nv_t = half; \ - __VA_ARGS__(); \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } -#else #define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) 
\ switch (TYPE) { \ case at::ScalarType::Half: { \ @@ -186,7 +174,6 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, default: \ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } -#endif namespace torchao { // MODIFICATION NOTE: dtype of _weights is changed to uint8 diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh index 843fc0721e..b008971647 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -53,11 +53,10 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, const size_t M_Global, const size_t N_Global, const size_t K_Global, int Split_K) { - #if __CUDA_ARCH__ < 750 - static_assert(false, "FP6: At least Turing generation (sm75) is required"); + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + static_assert(false, "Quant-LLM kernel: At least Turing generation (sm75) is required."); // __trap(); // fails at runtime instead of compile time #endif - #ifdef DEBUG_MODE assert(K_Global%TilingConfig::TILE_K==0); assert(M_Global%TilingConfig::TILE_M==0); @@ -158,11 +157,7 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, uint32_t Scales_RPTR[4]; // 4 Registers per thread for Quantization Scales ExtractFromSharedToReg_Scales(Scales_RPTR, QuantScales + WARP_i*64); // Initializing the Software Pipeline: writing registers. //////////////////////////////////////////////////////////////////////////////////////////////// - #if __CUDA_ARCH__ >= 800 constexpr bool USE_BF16 = std::is_same::value; - #else - constexpr bool USE_BF16 = false; - #endif initialize_mma_slice(a, b, AFrag_1BIT_SPTR, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); // The outer loop. ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// #pragma unroll(1) @@ -222,13 +217,14 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, #pragma unroll for(size_t j=threadIdx.x%WARP_SIZE; j::value) + if constexpr (std::is_same::value) { BlockGlobalPTR[j+i*M_Global] = __float2half_rn(smem_CFrag[i][j]); - #if __CUDA_ARCH__ >= 800 - else if constexpr (std::is_same::value) + } else if constexpr (std::is_same::value) { + #if __CUDA_ARCH__ >= 800 BlockGlobalPTR[j+i*M_Global] = __float2bfloat16_rn(smem_CFrag[i][j]); - #endif - else + #endif + } else { BlockGlobalPTR[j+i*M_Global] = smem_CFrag[i][j]; + } } } diff --git a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh index 009084bf10..d09d9b861d 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh @@ -62,16 +62,13 @@ __global__ void SplitK_Reduction(T* C, float* Reduction_Workspace, size_t M_Glob THREAD_GPTR_R += M_Global * N_Global; } // Writing to global memory - #if __CUDA_ARCH__ == 750 - #pragma unroll - for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]); - #else if constexpr (std::is_same::value) { #pragma unroll for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]); } else { // __nv_bfloat16> + #if __CUDA_ARCH__ >= 800 #pragma unroll for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2bfloat16_rn(Results[i]); - } #endif + } } diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index 04f4e2e621..7fb77f9f8b 100644 --- 
a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -65,9 +65,9 @@ constexpr float power_of_two(int n) { return (n == 0) ? 1.0f : 2.0f * power_of_two(n - 1); } -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 template __device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bfloat16 Scale) { +#if __CUDA_ARCH__ >= 800 constexpr int BIAS_OFFSET = (int(1) << (8-1)) - (int(1) << (EXPONENT-1)); constexpr float BIAS = power_of_two(BIAS_OFFSET); __nv_bfloat16* BF16_1 = reinterpret_cast<__nv_bfloat16*>(&PackedBF16Pair); @@ -77,8 +77,8 @@ __device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bflo output_bf16_ptr[0] = __hmul( __hmul(*BF16_1,__float2bfloat16(BIAS)), Scale); output_bf16_ptr[1] = __hmul( __hmul(*BF16_2,__float2bfloat16(BIAS)), Scale); return output; -} #endif +} // MODIFICATION NOTE: to support MSVC // - u_int32_t __restrict__ Reg[][4] is changed to below. @@ -99,11 +99,7 @@ __device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg) uint32_t *Frag_PTR_1bit = read_RPTR_1bit; uint32_t *Frag_PTR_2bit = read_RPTR_2bit; uint32_t *Frag_PTR_4bit = read_RPTR_4bit; - #if __CUDA_ARCH__ >= 800 using scalar_t = typename std::conditional::type; - #else - using scalar_t = half; - #endif scalar_t *Scale_RPTR = reinterpret_cast(Scales); // Dequantizing 32 FP6, each Loop dequantizing 4 FP6 #pragma unroll(8) From 379bd5ecd8d187011ca78c8de5abc98c0f4db11a Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Fri, 1 Nov 2024 11:18:34 +0100 Subject: [PATCH 29/30] Fix incorrect CUDA error handling --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 38 ++++++++++++++----------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index a4c24eccb2..559e243ca7 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -31,6 +31,19 @@ #include +// https://github.com/Dao-AILab/flash-attention/blob/478ee666cccbd1b8f63648633003059a8dc6827d/hopper/utils.h#L25 +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while(0) + +#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) + + template static void Kernel_Ex(cudaStream_t stream, const uint4 *Weight, @@ -62,11 +75,11 @@ static void Kernel_Ex(cudaStream_t stream, #endif QUANT_GEMM_Kernel<<>> (Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + CHECK_CUDA_KERNEL_LAUNCH(); } template -cudaError_t fpx_linear_kernel(cudaStream_t stream, +void fpx_linear_kernel(cudaStream_t stream, const uint4 *Weight, const half *Scales, const half *B, @@ -93,16 +106,12 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, // Check GPU Compute Capability int device, major, minor; - cudaError_t err = cudaGetDevice(&device); - if (err != cudaSuccess) return err; - err = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device); - if (err != cudaSuccess) return err; - err = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device); - if (err != cudaSuccess) return err; + CHECK_CUDA(cudaGetDevice(&device)); + CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); + 
CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); if ((major < 7) || (major == 7 && minor < 5)) { - printf("FP6LLM_API Error: FP6LLM requires GPU with SM75 or higher!\n"); - return cudaErrorUnknown; + TORCH_CHECK(false, "FP6LLM_API Error: FP6LLM requires GPU with SM75 or higher!\n"); } const bool is_sm75_gpu = (major == 7) && (minor == 5); @@ -123,8 +132,7 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, case 64: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; case 128: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; default: if (N_PowerOf2 % 128 != 0) { - printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); - return cudaErrorUnknown; + TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2); } Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; } @@ -137,8 +145,7 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, case 64: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; case 128: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; default: if (N_PowerOf2 % 128 != 0) { - printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); - return cudaErrorUnknown; + TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2); } Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; } @@ -150,9 +157,8 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1); dim3 BlockDim(WARP_SIZE, 1, 1); SplitK_Reduction<<>>(C, Reduction_Workspace, M_Global, N_Global, Split_K); + CHECK_CUDA_KERNEL_LAUNCH(); } - - return cudaGetLastError(); } From a6de35a17a4c46b0b254e8a37312d6dd2612b2c3 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Fri, 1 Nov 2024 11:37:29 +0100 Subject: [PATCH 30/30] Make the kernel fail for sm75 + bfloat16 inputs --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 559e243ca7..6141dc3d74 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -109,12 +109,11 @@ void fpx_linear_kernel(cudaStream_t stream, CHECK_CUDA(cudaGetDevice(&device)); CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); - - if ((major < 7) || (major == 7 && minor < 5)) { - TORCH_CHECK(false, "FP6LLM_API Error: FP6LLM requires GPU with SM75 or higher!\n"); - } - const bool is_sm75_gpu = (major == 7) && (minor == 5); + if (is_sm75_gpu && std::is_same::value) + TORCH_CHECK(false, "Bfloat16 inputs are not supported for SM75"); + if ((major < 7) || (major == 7 && minor < 5)) + TORCH_CHECK(false, "FP6LLM_API Error: FP6LLM requires GPU with SM75 or higher!\n"); if (is_sm75_gpu && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || 
N_PowerOf2 % 128 == 0)) { // For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory.