From 9288321325d87052b72ad6c915068840c145d2a1 Mon Sep 17 00:00:00 2001 From: DerekLiu35 Date: Thu, 23 Jan 2025 11:14:30 -0500 Subject: [PATCH 1/2] add code1x16 kernel --- torchao/csrc/cuda/codebook/codebook.cu | 476 ++++++++++++++++++ torchao/ops.py | 173 ++++++- .../quantization/codebook/codebook_ops.py | 6 +- .../codebook/codebook_quantized_tensor.py | 185 +++++-- 4 files changed, 798 insertions(+), 42 deletions(-) create mode 100644 torchao/csrc/cuda/codebook/codebook.cu diff --git a/torchao/csrc/cuda/codebook/codebook.cu b/torchao/csrc/cuda/codebook/codebook.cu new file mode 100644 index 0000000000..4673a7c580 --- /dev/null +++ b/torchao/csrc/cuda/codebook/codebook.cu @@ -0,0 +1,476 @@ +// modified from https://github.com/Vahe1994/AQLM/tree/ab272bfe09915f84bc4e2439055dd7d0e82e08ca/inference_lib/src/aqlm/inference_kernels +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace torchao { + +template +__global__ void Code1x16MatVec( + const int4* __restrict__ A, + const int4* __restrict__ B, + int4* __restrict__ C, + const int4* __restrict__ codebook, + int prob_m, + int prob_k +) { + int a_gl_stride = prob_k / 8 / group_size; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + int b_gl_rd = 0; + int c_gl_wr = a_gl_rd; + a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; + int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; + + __shared__ int4 sh_b[32 * (group_size + 1)]; + float res = 0; + + int iters = (prob_k / group_size + group_size * 32 - 1) / (group_size * 32); + while (iters--) { + // We pad shared memory to avoid bank conflicts during reads + __syncthreads(); + for (int i = threadIdx.x; i < 32 * group_size; i += blockDim.x) { + if (8 * (b_gl_rd + i) < prob_k) + sh_b[(group_size + 1) * (i / group_size) + i % group_size] = B[b_gl_rd + i]; + } + __syncthreads(); + b_gl_rd += 32 * group_size; + + int b_sh_rd = (group_size + 1) * (threadIdx.x % 32); + if (pred && a_gl_rd < a_gl_end) { + const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); + #pragma unroll + for (int i = 0; i < 8; i++) { + uint32_t dec[group_size / 2]; + // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't + // actually help us; this brings > 2x speedup. 
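// Editorial note (illustrative, not part of the upstream AQLM kernel comments): each 16-bit
// code enc[i] selects one group_size-wide centroid row of the codebook. The ld.cg.global.v4.u32
// below loads 16 bytes (one int4, i.e. eight fp16/bf16 values) of that centroid per instruction,
// so the group_size == 16 case issues a second load for the remaining eight values. The decoded
// group is multiplied against the matching activation slice cached in sh_b with packed
// __hfma2 ops, folded into the per-lane float accumulator `res`, and later reduced across the
// warp with __shfl_down_sync before the single write to C.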
+ asm volatile ( + "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*) &codebook[(group_size / 8) * enc[i]]) + ); + if constexpr (group_size == 16) { + asm volatile ( + "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[4]), "=r"(dec[5]), "=r"(dec[6]), "=r"(dec[7]) + : "l"((void*) &codebook[(group_size / 8) * enc[i] + 1]) + ); + } + if constexpr (use_bfloat16) { + #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800) + nv_bfloat162* a = reinterpret_cast(&dec); + nv_bfloat162* b = reinterpret_cast(&sh_b[b_sh_rd]); + nv_bfloat162 res2 = {}; + #pragma unroll + for (int j = 0; j < group_size / 2; j++) + res2 = __hfma2(a[j], b[j], res2); + res += __bfloat162float(res2.x) + __bfloat162float(res2.y); + #endif + } else { + half2* a = reinterpret_cast(&dec); + half2* b = reinterpret_cast(&sh_b[b_sh_rd]); + half2 res2 = {}; + #pragma unroll + for (int j = 0; j < group_size / 2; j++) + res2 = __hfma2(a[j], b[j], res2); + res += __half2float(res2.x) + __half2float(res2.y); + } + b_sh_rd += group_size / 8; + } + a_gl_rd += 32; + } + } + + if (pred) { + #pragma unroll + for (int i = 16; i > 0; i /= 2) + res += __shfl_down_sync(0xffffffff, res, i); + if (threadIdx.x % 32 == 0) { + if constexpr (use_bfloat16) { + reinterpret_cast<__nv_bfloat16*>(C)[c_gl_wr] = __float2bfloat16(res); + } else { + reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); + } + } + } +} + + +template +__global__ void Code1x16Dequant( + const int4* __restrict__ A, + int4* __restrict__ C, + const int4* __restrict__ codebook, + int prob_m, + int prob_k +) { + int a_gl_stride = prob_k / 8 / group_size; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; + int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; + + int iters = (prob_k / group_size + group_size * 32 - 1) / (group_size * 32); + while (iters--) { + if (pred && a_gl_rd < a_gl_end) { + const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); + #pragma unroll + for (int i = 0; i < 8; i++) { + uint32_t dec[group_size / 2]; + // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't + // actually help us; this brings > 2x speedup. + asm volatile ( + "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*) &codebook[(group_size / 8) * enc[i]]) + ); + if constexpr (group_size == 16) { + asm volatile ( + "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[4]), "=r"(dec[5]), "=r"(dec[6]), "=r"(dec[7]) + : "l"((void*) &codebook[(group_size / 8) * enc[i] + 1]) + ); + } + + C[a_gl_rd * group_size + (group_size / 8) * i] = reinterpret_cast(&dec)[0]; + if constexpr (group_size == 16) { + C[a_gl_rd * group_size + (group_size / 8) * i + 1] = reinterpret_cast(&dec)[1]; + } + } + } + a_gl_rd += 32; + } +} + +inline int ceildiv(int a, int b) { + return (a + b - 1) / b; +} + +const int THREAD_M = 16; + +template +void code1x16_matvec_cuda( + const void* __restrict__ A, + const void* __restrict__ B, + void* __restrict__ C, + const void* __restrict__ codebook, + int prob_m, + int prob_k +) { + int cc_major; + cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, 0); + if (cc_major < 8 && use_bfloat16) { + throw c10::TypeError( + {__func__, __FILE__, static_cast(__LINE__)}, + c10::str( + "You're trying to run AQLM with bfloat16 on a GPU with low compute capability. Use torch.float16 instead." 
+ ) + ); + } + + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; + int thread_m; + do { + waves++; + thread_m = ceildiv(prob_m, waves * sms); + } while (thread_m > THREAD_M); + + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + Code1x16MatVec<<>>( + (const int4*) A, + (const int4*) B, + (int4*) C, + (const int4*) codebook, + prob_m, + prob_k + ); +} + +template void code1x16_matvec_cuda(const void*, const void*, void*, const void*, int, int); +template void code1x16_matvec_cuda(const void*, const void*, void*, const void*, int, int); +template void code1x16_matvec_cuda(const void*, const void*, void*, const void*, int, int); +template void code1x16_matvec_cuda(const void*, const void*, void*, const void*, int, int); + +template +void code1x16_dequant_cuda( + const void* __restrict__ A, + void* __restrict__ C, + const void* __restrict__ codebook, + int prob_m, + int prob_k +) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; + int thread_m; + do { + waves++; + thread_m = ceildiv(prob_m, waves * sms); + } while (thread_m > THREAD_M); + + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + Code1x16Dequant<<>>( + (const int4*) A, + (int4*) C, + (const int4*) codebook, + prob_m, + prob_k + ); +} + +template void code1x16_dequant_cuda<8>(const void*, void*, const void*, int, int); +template void code1x16_dequant_cuda<16>(const void*, void*, const void*, int, int); + + + + +inline bool check_use_bfloat16(const torch::Tensor& input) { + auto dtype = input.dtype(); + if (dtype == at::kHalf) { + return false; + } else if (dtype == at::kBFloat16) { + return true; + } else { + throw c10::NotImplementedError( + {__func__, __FILE__, static_cast(__LINE__)}, + c10::str( + "AQLM CUDA kernels only support float16 and bfloat16. Got ", + dtype.name(), + ". Please specify the correct `torch_dtype` when loading the model." 
+ ) + ); + } +} + +inline torch::Tensor scale_bias_unflatten_output( + torch::Tensor& flat_output, + const torch::Tensor& scales, + const std::optional& bias, + const c10::IntArrayRef& input_sizes +) { + flat_output *= scales.flatten().unsqueeze(0); + if (bias.has_value()) { + flat_output += bias->unsqueeze(0); + } + + auto output_sizes = input_sizes.vec(); + output_sizes.pop_back(); + output_sizes.push_back(flat_output.size(-1)); + auto output = flat_output.reshape(output_sizes).clone(); + return output; +} + +void code1x16_matvec( + const torch::Tensor& A, + const torch::Tensor& B, + torch::Tensor& C, + const torch::Tensor& codebook, + const bool use_bfloat16 +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + int prob_m = C.size(0); + int prob_k = B.size(0); + + if (codebook.size(3) == 8) { + if (use_bfloat16) { + code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), codebook.data_ptr(), prob_m, prob_k); + } else { + code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), codebook.data_ptr(), prob_m, prob_k); + } + } else if (codebook.size(3) == 16) { + if (use_bfloat16) { + code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), codebook.data_ptr(), prob_m, prob_k); + } else { + code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), codebook.data_ptr(), prob_m, prob_k); + } + } else { + throw c10::NotImplementedError( + {__func__, __FILE__, static_cast(__LINE__)}, + c10::str( + "AQLM CUDA kernels only support codebooks with 8 or 16 features. Got ", + codebook.size(3), + "." + ) + ); + } +} + +torch::Tensor code1x16_matmat( + const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const std::optional& bias +) { + bool use_bfloat16 = check_use_bfloat16(input); + auto input_sizes = input.sizes(); + auto out_features = codes.size(0) * codebooks.size(2); + auto flat_input = input.reshape({-1, input.size(-1)}); + auto flat_output = torch::empty({flat_input.size(0), out_features}, + torch::TensorOptions() + .dtype(input.dtype()) + .device(input.device()) + ); + + for (int i = 0; i < flat_input.size(0); ++i) { + auto input_vec = flat_input.index({i}); + auto output_vec = flat_output.index({i}); + code1x16_matvec( + codes.squeeze(2), + input_vec, + output_vec, + codebooks, + use_bfloat16 + ); + } + return scale_bias_unflatten_output( + flat_output, + scales, + bias, + input_sizes + ); +} + +torch::Tensor code1x16_dequant( + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales +) { + check_use_bfloat16(codebooks); + auto in_features = codes.size(1) * codebooks.size(3); + auto out_features = scales.size(0); + + auto weight = torch::empty({out_features, in_features}, + torch::TensorOptions() + .dtype(codebooks.dtype()) + .device(codebooks.device()) + ); + if (codebooks.size(3) == 8) { + code1x16_dequant_cuda<8>( + codes.data_ptr(), + weight.data_ptr(), + codebooks.data_ptr(), + out_features, + in_features + ); + } else if (codebooks.size(3) == 16) { + code1x16_dequant_cuda<16>( + codes.data_ptr(), + weight.data_ptr(), + codebooks.data_ptr(), + out_features, + in_features + ); + } else { + throw c10::NotImplementedError( + {__func__, __FILE__, static_cast(__LINE__)}, + c10::str( + "AQLM CUDA kernels only support codebooks with 8 or 16 features. Got ", + codebooks.size(3), + "." 
+ ) + ); + } + weight *= scales.index({"...", 0, 0}); + + return weight; +} + +int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) +{ + int4 cumulative_sizes; + auto cumulative_size = &cumulative_sizes.x; + int i = 0; + int last = 0; + assert(codebook_partition_sizes.size(0) <= 4); + for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) + { + *cumulative_size = codebook_partition_sizes[i].item() + last; + last = *cumulative_size; + } + // fill in the rest with unreachable. + for (; i < 4; ++i, ++cumulative_size) + { + *cumulative_size = last*10; + } + return cumulative_sizes; +} + +torch::Tensor code1x16_matmat_dequant( + const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const std::optional& bias +) { + bool use_bfloat16 = check_use_bfloat16(input); + + auto input_sizes = input.sizes(); + auto in_features = codes.size(1) * codebooks.size(3); + auto out_features = codes.size(0) * codebooks.size(2); + auto flat_input = input.reshape({-1, input.size(-1)}); + + auto weight = torch::empty({out_features, in_features}, + torch::TensorOptions() + .dtype(codebooks.dtype()) + .device(codebooks.device()) + ); + if (codebooks.size(3) == 8) { + code1x16_dequant_cuda<8>( + codes.data_ptr(), + weight.data_ptr(), + codebooks.data_ptr(), + out_features, + in_features + ); + } else if (codebooks.size(3) == 16) { + code1x16_dequant_cuda<16>( + codes.data_ptr(), + weight.data_ptr(), + codebooks.data_ptr(), + out_features, + in_features + ); + } else { + throw c10::NotImplementedError( + {__func__, __FILE__, static_cast(__LINE__)}, + c10::str( + "AQLM CUDA kernels only support codebooks with 8 or 16 features. Got ", + codebooks.size(3), + "." + ) + ); + } + + auto flat_output = at::native::linear(flat_input, weight); + return scale_bias_unflatten_output( + flat_output, + scales, + bias, + input_sizes + ); +} + + + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("code1x16_matmat", &code1x16_matmat); + m.impl("code1x16_matmat_dequant", &code1x16_matmat_dequant); +} + +} // namespace torchao \ No newline at end of file diff --git a/torchao/ops.py b/torchao/ops.py index f4b55c4951..4cfc8edd60 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from torch import Tensor @@ -22,7 +24,12 @@ lib.define( "s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" ) - +lib.define( + "code1x16_matmat(Tensor input, Tensor codes, Tensor codebooks, Tensor scales, Tensor? bias=None) -> Tensor" +) +lib.define( + "code1x16_matmat_dequant(Tensor input, Tensor codes, Tensor codebooks, Tensor scales, Tensor? bias=None) -> Tensor" +) def register_custom_op(name): def decorator(func): @@ -615,3 +622,167 @@ def _( dtype=input_scale.dtype, device=input.device, ) + +def code1x16_matmat( + input: Tensor, + codes: Tensor, + codebooks: Tensor, + scales: Tensor, + bias: Optional[Tensor] = None, +) -> Tensor: + """ + Performs a matrix multiplication using codebooks and codes (quantized weights). + + Args: + input (Tensor): Input tensor of shape `(..., in_features)`. 
+ codes (Tensor): Codes tensor of shape `(num_out_groups=out_features, num_in_groups=in_features//in_group_size, num_codebooks=1)` + codebooks (Tensor): Codebooks tensor of shape `(num_codebooks=1, codebook_size=2**16, out_group_size=1, in_group_size)` + scales (Tensor): Scales tensor of shape `(num_out_groups, 1, 1, 1)` + bias (Optional[Tensor]): Optional bias tensor of shape `[out_features]`. + + Returns: + Tensor: Output tensor after the matrix multiplication. + """ + return torch.ops.torchao.code1x16_matmat.default(input, codes, codebooks, scales, bias) + +@register_custom_op("torchao::code1x16_matmat") +def _( + input: Tensor, + codes: Tensor, + codebooks: Tensor, + scales: Tensor, + bias: Optional[Tensor] = None, +) -> Tensor: + + num_out_groups, num_in_groups, num_codebooks = codes.shape + num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape + + torch._check( + input.is_cuda, + lambda: "input is not on GPU" + ) + torch._check( + codebooks.is_cuda, + lambda: "codebooks is not on GPU" + ) + torch._check( + num_codebooks == 1, + lambda: f"num_codebooks must equat 1, got {num_codebooks}" + ) + torch._check( + codebook_size == 65536, + lambda: f"codebook_size must equal 65536, got {codebook_size}" + ) + torch._check( + out_group_size == 1, + lambda: f"out_group_size must equal 1, got {out_group_size}" + ) + torch._check( + in_group_size in [8, 16], + lambda: f"in_group_size must equal 8 or 16, got {in_group_size}" + ) + + # Validate dimensions + input_features = input.size(-1) + in_features = num_in_groups * in_group_size + out_features = num_out_groups + + torch._check( + input_features == in_features, + lambda: f"Input features ({input_features}) do not match the expected size ({in_features})." + ) + torch._check( + scales.size(0) == out_features, + lambda: f"Scales tensor size ({scales.size(0)}) does not match the number of output features ({out_features})." + ) + if bias is not None: + torch._check( + bias.size(0) == out_features, + lambda: f"Bias tensor size ({bias.size(0)}) does not match the number of output features ({out_features})." + ) + + # Compute output shape + output_shape = input.shape[:-1] + (out_features,) + return input.new_empty(output_shape) + +def code1x16_matmat_dequant( + input: Tensor, + codes: Tensor, + codebooks: Tensor, + scales: Tensor, + bias: Optional[Tensor] = None, +) -> Tensor: + """ + Dequantizes and performs a matrix multiplication using codebooks and codes. + + Args: + input (Tensor): Input tensor of shape `(..., in_features)`. + codes (Tensor): Codes tensor of shape `(num_out_groups=out_features, num_in_groups=in_features//in_group_size)` + codebooks (Tensor): Codebooks tensor of shape `(num_codebooks=1, codebook_size=2**16, out_group_size=1, in_group_size)` + scales (Tensor): Scales tensor of shape `(num_out_groups, 1, 1, 1)` + bias (Optional[Tensor]): Optional bias tensor of shape `[out_features]`. + + Returns: + Tensor: Output tensor after dequantization and matrix multiplication. 
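    Note: although the shape above is written as 2-D, the registered meta function and the
    CUDA op expect a trailing singleton codebook dimension, i.e.
    `(num_out_groups, num_in_groups, num_codebooks=1)`.

    Example (illustrative only; the shapes follow the layout above, while the int16 code
    dtype and the random values are assumptions, not something this wrapper checks)::

        import torch

        x = torch.randn(2, 4096, dtype=torch.float16, device="cuda")
        codes = torch.randint(-2**15, 2**15, (11008, 256, 1), dtype=torch.int16, device="cuda")
        codebooks = torch.randn(1, 2**16, 1, 16, dtype=torch.float16, device="cuda")
        scales = torch.randn(11008, 1, 1, 1, dtype=torch.float16, device="cuda")
        out = code1x16_matmat_dequant(x, codes, codebooks, scales)  # shape (2, 11008)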
+ """ + return torch.ops.torchao.code1x16_matmat_dequant.default(input, codes, codebooks, scales, bias) + +@register_custom_op("torchao::code1x16_matmat_dequant") +def _( + input: Tensor, + codes: Tensor, + codebooks: Tensor, + scales: Tensor, + bias: Optional[Tensor] = None, +) -> Tensor: + + num_out_groups, num_in_groups, num_codebooks = codes.shape + num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape + + torch._check( + input.is_cuda, + lambda: "input is not on GPU" + ) + torch._check( + codebooks.is_cuda, + lambda: "codebooks is not on GPU" + ) + torch._check( + num_codebooks == 1, + lambda: f"num_codebooks must equat 1, got {num_codebooks}" + ) + torch._check( + codebook_size == 65536, + lambda: f"codebook_size must equal 65536, got {codebook_size}" + ) + torch._check( + out_group_size == 1, + lambda: f"out_group_size must equal 1, got {out_group_size}" + ) + torch._check( + in_group_size in [8, 16], + lambda: f"in_group_size must equal 8 or 16, got {in_group_size}" + ) + + # Validate dimensions + input_features = input.size(-1) + in_features = num_in_groups * in_group_size + out_features = num_out_groups + + torch._check( + input_features == in_features, + lambda: f"Input features ({input_features}) do not match the expected size ({in_features})." + ) + torch._check( + scales.size(0) == out_features, + lambda: f"Scales tensor size ({scales.size(0)}) does not match the number of output features ({out_features})." + ) + if bias is not None: + torch._check( + bias.size(0) == out_features, + lambda: f"Bias tensor size ({bias.size(0)}) does not match the number of output features ({out_features})." + ) + + # Compute output shape + output_shape = input.shape[:-1] + (out_features,) + return input.new_empty(output_shape) \ No newline at end of file diff --git a/torchao/prototype/quantization/codebook/codebook_ops.py b/torchao/prototype/quantization/codebook/codebook_ops.py index 4c0b371c69..a6a9314bbe 100644 --- a/torchao/prototype/quantization/codebook/codebook_ops.py +++ b/torchao/prototype/quantization/codebook/codebook_ops.py @@ -152,7 +152,7 @@ def choose_qparams_codebook( code_dtype: torch.dtype, max_iter: int = 200, devices: Optional[List[torch.device]] = None, -) -> torch.Tensor: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Initialize the codebook using k-means clustering on blocks of the input tensor. @@ -165,7 +165,9 @@ def choose_qparams_codebook( devices (List[torch.device]): Devices to run k-means on. Returns: - torch.Tensor: The codebook tensor, shape (codebook_size, *block_size). + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - codebook (torch.Tensor): Shape (codebook_size, *block_size). + - scales (torch.Tensor): Shape corresponding to scale blocks. """ if code_dtype == torch.int32: codebook_size = 2**16 diff --git a/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py b/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py index b7e395b434..a3a1f57aa5 100644 --- a/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py +++ b/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py @@ -28,6 +28,7 @@ class CodebookQuantizedTensor(TorchAOBaseTensor): corresponds to a block in the original tensor. Shape is `(codebook_size, out_block_size, in_block_size)`. block_size (Tuple[int, ...]): Granularity of quantization, specifying the dimensions of tensor blocks that share the same quantization parameters. + scales (torch.Tensor): Scaling factors for each scale block. 
shape (torch.Size): Shape of the original high-precision tensor. dtype (torch.dtype): dtype of the original high-precision tensor. """ @@ -138,6 +139,7 @@ def from_float( input_tensor (torch.Tensor): The input floating-point tensor to quantize. block_size (Tuple[int, ...]): The size of the blocks for which codes are assigned. code_dtype (torch.dtype): The dtype of the codes. + scale_block_size (int): The size of the blocks that share a scale. chunk_size (int): The chunk size to use during quantization (to control memory usage). """ @@ -187,24 +189,6 @@ def _apply_fn_to_data(self, fn): dtype=self.dtype, ) - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - - if func in CODEBOOK_TORCH_FUNCTIONS: - return CODEBOOK_TORCH_FUNCTIONS[func](*args, **kwargs) - - if any(isinstance(arg, cls) for arg in args): - # Dequantize all instances of CodebookQuantizedTensor in args - new_args = tuple( - arg.dequantize() if isinstance(arg, cls) else arg for arg in args - ) - - return func(*new_args, **kwargs) - else: - return NotImplemented - def detach(self): """ Returns a new `CodebookQuantizedTensor`. @@ -233,27 +217,6 @@ def dtype(self): return self._dtype -CODEBOOK_TORCH_FUNCTIONS = {} - - -def implements_torch_function(torch_function): - def decorator(func): - CODEBOOK_TORCH_FUNCTIONS[torch_function] = func - return func - - return decorator - - -@implements_torch_function(torch.Tensor.detach) -def function_detach(tensor, *args, **kwargs): - return tensor.detach() - - -@implements_torch_function(torch.Tensor.requires_grad_) -def function_requires_grad_(tensor, *args, **kwargs): - return tensor.requires_grad_(*args, **kwargs) - - def codebook_weight_only( dtype=torch.uint4, block_size: Tuple[int, int] = (1, 1), @@ -286,3 +249,147 @@ def apply_codebook_quantization(weight, scale_block_size): return _get_linear_subclass_inserter( apply_codebook_quantization, scale_block_size=scale_block_size ) + + + +import logging + +from torch.utils._python_dispatch import return_and_correct_aliasing + +logger = logging.getLogger(__name__) + +# aten = torch.ops.aten + +_CODEBOOK_QLINEAR_DISPATCH_TABLE = {} + +def register_codebook_quantized_linear_dispatch(dispatch_condition, impl): + """ + Register a dispatch for codebook-based quantized linear op with a (condition, impl) pair. + Both must accept (input, weight, bias). + """ + _CODEBOOK_QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl + +def deregister_codebook_quantized_linear_dispatch(dispatch_condition): + if dispatch_condition in _CODEBOOK_QLINEAR_DISPATCH_TABLE: + del _CODEBOOK_QLINEAR_DISPATCH_TABLE[dispatch_condition] + else: + logger.warning( + f"Attempting to deregister non-existent codebook dispatch condition: {dispatch_condition}" + ) + +class CodebookLinearNotImplementedError(NotImplementedError): + """Thin wrapper around NotImplementedError to make codebook errors more explicit.""" + pass + +@staticmethod +def _codebook_linear_op(input_tensor: torch.Tensor, + weight_tensor: CodebookQuantizedTensor, + bias: torch.Tensor): + """ + Tries each (dispatch_condition, impl) in the codebook quantized linear dispatch table. + Raises if no specialized path is found. + """ + for condition, impl in _CODEBOOK_QLINEAR_DISPATCH_TABLE.items(): + if condition(input_tensor, weight_tensor, bias): + return impl(input_tensor, weight_tensor, bias) + raise CodebookLinearNotImplementedError( + "No specialized codebook dispatch found for quantized linear op." 
+ ) + +# Attach the _codebook_linear_op to the CodebookQuantizedTensor class +CodebookQuantizedTensor._codebook_linear_op = _codebook_linear_op + +def adapt_codebook_1x16(cqt): + """ + Given a CodebookQuantizedTensor `cqt` with block_size=(1, 16), + reshape codebook, codes, scales into the layout needed by AQLM’s 1x16 kernel. + + Returns: + codebooks_aqlm, codes_aqlm, scales_aqlm + """ + # We expect codebook.shape == [codebook_size, 1, 16]. + # AQLM requires shape [num_codebooks=1, codebook_size, out_group_size=1, in_group_size=16]. + codebooks_aqlm = cqt.codebook.unsqueeze(0) #.contiguous() + + # AQLM expects codes.shape == [num_out_groups, num_in_groups, num_codebooks]. + # `cqt.codes` is [out_groups, in_groups], we just add the last dim: + codes_aqlm = cqt.codes.unsqueeze(-1) #.contiguous() + + # AQLM expects scales.shape == [num_out_groups, 1, 1, 1]. + # `cqt.scales` is [num_out_groups, num_scale_groups=1, 1] do: + scales_aqlm = cqt.scales.unsqueeze(-1) #.contiguous() + + return codebooks_aqlm, codes_aqlm, scales_aqlm + +def _linear_aqlm_code1x16_check( + input_tensor: torch.Tensor, + weight_tensor: torch.Tensor, + bias: torch.Tensor +) -> bool: + + # don't need adapt_codebook_1x16 and other reshaping if refactored to follow AQLM data representation + codebook_size, out_group_size, in_group_size = weight_tensor.codebook.shape + num_codebooks = 1 # right now this is hardcoded, won't be if supporting AQLM + + return ( + isinstance(weight_tensor, CodebookQuantizedTensor) + and (weight_tensor.codebook.device.type, num_codebooks, codebook_size, out_group_size) == ( + "cuda", + 1, + 65536, + 1, + ) + and in_group_size in [8, 16] + ) + +def _linear_aqlm_code1x16_impl( + input_tensor: torch.Tensor, + weight_tensor: CodebookQuantizedTensor, + bias: torch.Tensor +) -> torch.Tensor: + """ + Codebook linear implementation using the `code1x16_matmat_dequant` op. 
+ """ + + codebook, codes, scales = adapt_codebook_1x16(weight_tensor) + from torchao.ops import code1x16_matmat, code1x16_matmat_dequant + + return code1x16_matmat( + input=input_tensor, + codes=codes, + codebooks=codebook, + scales=scales, + bias=bias, + ) + +register_codebook_quantized_linear_dispatch( + _linear_aqlm_code1x16_check, + _linear_aqlm_code1x16_impl +) + + +implements = CodebookQuantizedTensor.implements + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + if not input_tensor.is_floating_point(): + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) + + # using try/except here so that we can have a general fallback when input_tensor/weight_tensor + # is not picked up by any of the dispatch paths in `_codebook_linear_op`, this allows us to + # make the branches easier to understand in `_codebook_linear_op` + try: + return weight_tensor._codebook_linear_op(input_tensor, weight_tensor, bias) + except CodebookLinearNotImplementedError: + if isinstance(input_tensor, CodebookQuantizedTensor): + input_tensor = input_tensor.dequantize() + if isinstance(weight_tensor, CodebookQuantizedTensor): + weight_tensor = weight_tensor.dequantize() + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) \ No newline at end of file From ee09e00498ab372e591f99f4cdd62b3b30605d99 Mon Sep 17 00:00:00 2001 From: DerekLiu35 Date: Thu, 23 Jan 2025 12:36:17 -0500 Subject: [PATCH 2/2] fix lint --- torchao/csrc/cuda/codebook/codebook.cu | 2 +- torchao/ops.py | 73 +++++++++---------- .../codebook/codebook_quantized_tensor.py | 54 ++++++++------ 3 files changed, 67 insertions(+), 62 deletions(-) diff --git a/torchao/csrc/cuda/codebook/codebook.cu b/torchao/csrc/cuda/codebook/codebook.cu index 4673a7c580..34f343ae47 100644 --- a/torchao/csrc/cuda/codebook/codebook.cu +++ b/torchao/csrc/cuda/codebook/codebook.cu @@ -473,4 +473,4 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) { m.impl("code1x16_matmat_dequant", &code1x16_matmat_dequant); } -} // namespace torchao \ No newline at end of file +} // namespace torchao diff --git a/torchao/ops.py b/torchao/ops.py index 4cfc8edd60..a81f8a8358 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -31,6 +31,7 @@ "code1x16_matmat_dequant(Tensor input, Tensor codes, Tensor codebooks, Tensor scales, Tensor? bias=None) -> Tensor" ) + def register_custom_op(name): def decorator(func): if TORCH_VERSION_AT_LEAST_2_4: @@ -623,6 +624,7 @@ def _( device=input.device, ) + def code1x16_matmat( input: Tensor, codes: Tensor, @@ -643,7 +645,10 @@ def code1x16_matmat( Returns: Tensor: Output tensor after the matrix multiplication. 
""" - return torch.ops.torchao.code1x16_matmat.default(input, codes, codebooks, scales, bias) + return torch.ops.torchao.code1x16_matmat.default( + input, codes, codebooks, scales, bias + ) + @register_custom_op("torchao::code1x16_matmat") def _( @@ -653,33 +658,25 @@ def _( scales: Tensor, bias: Optional[Tensor] = None, ) -> Tensor: - num_out_groups, num_in_groups, num_codebooks = codes.shape num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape + torch._check(input.is_cuda, lambda: "input is not on GPU") + torch._check(codebooks.is_cuda, lambda: "codebooks is not on GPU") torch._check( - input.is_cuda, - lambda: "input is not on GPU" + num_codebooks == 1, lambda: f"num_codebooks must equat 1, got {num_codebooks}" ) torch._check( - codebooks.is_cuda, - lambda: "codebooks is not on GPU" - ) - torch._check( - num_codebooks == 1, - lambda: f"num_codebooks must equat 1, got {num_codebooks}" - ) - torch._check( - codebook_size == 65536, - lambda: f"codebook_size must equal 65536, got {codebook_size}" + codebook_size == 65536, + lambda: f"codebook_size must equal 65536, got {codebook_size}", ) torch._check( out_group_size == 1, - lambda: f"out_group_size must equal 1, got {out_group_size}" + lambda: f"out_group_size must equal 1, got {out_group_size}", ) torch._check( in_group_size in [8, 16], - lambda: f"in_group_size must equal 8 or 16, got {in_group_size}" + lambda: f"in_group_size must equal 8 or 16, got {in_group_size}", ) # Validate dimensions @@ -689,22 +686,23 @@ def _( torch._check( input_features == in_features, - lambda: f"Input features ({input_features}) do not match the expected size ({in_features})." + lambda: f"Input features ({input_features}) do not match the expected size ({in_features}).", ) torch._check( scales.size(0) == out_features, - lambda: f"Scales tensor size ({scales.size(0)}) does not match the number of output features ({out_features})." + lambda: f"Scales tensor size ({scales.size(0)}) does not match the number of output features ({out_features}).", ) if bias is not None: torch._check( bias.size(0) == out_features, - lambda: f"Bias tensor size ({bias.size(0)}) does not match the number of output features ({out_features})." + lambda: f"Bias tensor size ({bias.size(0)}) does not match the number of output features ({out_features}).", ) # Compute output shape output_shape = input.shape[:-1] + (out_features,) return input.new_empty(output_shape) + def code1x16_matmat_dequant( input: Tensor, codes: Tensor, @@ -725,7 +723,10 @@ def code1x16_matmat_dequant( Returns: Tensor: Output tensor after dequantization and matrix multiplication. 
""" - return torch.ops.torchao.code1x16_matmat_dequant.default(input, codes, codebooks, scales, bias) + return torch.ops.torchao.code1x16_matmat_dequant.default( + input, codes, codebooks, scales, bias + ) + @register_custom_op("torchao::code1x16_matmat_dequant") def _( @@ -735,33 +736,25 @@ def _( scales: Tensor, bias: Optional[Tensor] = None, ) -> Tensor: - num_out_groups, num_in_groups, num_codebooks = codes.shape num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape + torch._check(input.is_cuda, lambda: "input is not on GPU") + torch._check(codebooks.is_cuda, lambda: "codebooks is not on GPU") torch._check( - input.is_cuda, - lambda: "input is not on GPU" - ) - torch._check( - codebooks.is_cuda, - lambda: "codebooks is not on GPU" + num_codebooks == 1, lambda: f"num_codebooks must equat 1, got {num_codebooks}" ) torch._check( - num_codebooks == 1, - lambda: f"num_codebooks must equat 1, got {num_codebooks}" - ) - torch._check( - codebook_size == 65536, - lambda: f"codebook_size must equal 65536, got {codebook_size}" + codebook_size == 65536, + lambda: f"codebook_size must equal 65536, got {codebook_size}", ) torch._check( out_group_size == 1, - lambda: f"out_group_size must equal 1, got {out_group_size}" + lambda: f"out_group_size must equal 1, got {out_group_size}", ) torch._check( in_group_size in [8, 16], - lambda: f"in_group_size must equal 8 or 16, got {in_group_size}" + lambda: f"in_group_size must equal 8 or 16, got {in_group_size}", ) # Validate dimensions @@ -771,18 +764,18 @@ def _( torch._check( input_features == in_features, - lambda: f"Input features ({input_features}) do not match the expected size ({in_features})." + lambda: f"Input features ({input_features}) do not match the expected size ({in_features}).", ) torch._check( scales.size(0) == out_features, - lambda: f"Scales tensor size ({scales.size(0)}) does not match the number of output features ({out_features})." + lambda: f"Scales tensor size ({scales.size(0)}) does not match the number of output features ({out_features}).", ) if bias is not None: torch._check( bias.size(0) == out_features, - lambda: f"Bias tensor size ({bias.size(0)}) does not match the number of output features ({out_features})." + lambda: f"Bias tensor size ({bias.size(0)}) does not match the number of output features ({out_features}).", ) - + # Compute output shape output_shape = input.shape[:-1] + (out_features,) - return input.new_empty(output_shape) \ No newline at end of file + return input.new_empty(output_shape) diff --git a/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py b/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py index a3a1f57aa5..cbc5ba5575 100644 --- a/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py +++ b/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py @@ -251,17 +251,15 @@ def apply_codebook_quantization(weight, scale_block_size): ) - import logging -from torch.utils._python_dispatch import return_and_correct_aliasing - logger = logging.getLogger(__name__) # aten = torch.ops.aten _CODEBOOK_QLINEAR_DISPATCH_TABLE = {} + def register_codebook_quantized_linear_dispatch(dispatch_condition, impl): """ Register a dispatch for codebook-based quantized linear op with a (condition, impl) pair. 
@@ -269,6 +267,7 @@ def register_codebook_quantized_linear_dispatch(dispatch_condition, impl): """ _CODEBOOK_QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl + def deregister_codebook_quantized_linear_dispatch(dispatch_condition): if dispatch_condition in _CODEBOOK_QLINEAR_DISPATCH_TABLE: del _CODEBOOK_QLINEAR_DISPATCH_TABLE[dispatch_condition] @@ -277,14 +276,19 @@ def deregister_codebook_quantized_linear_dispatch(dispatch_condition): f"Attempting to deregister non-existent codebook dispatch condition: {dispatch_condition}" ) + class CodebookLinearNotImplementedError(NotImplementedError): """Thin wrapper around NotImplementedError to make codebook errors more explicit.""" + pass + @staticmethod -def _codebook_linear_op(input_tensor: torch.Tensor, - weight_tensor: CodebookQuantizedTensor, - bias: torch.Tensor): +def _codebook_linear_op( + input_tensor: torch.Tensor, + weight_tensor: CodebookQuantizedTensor, + bias: torch.Tensor, +): """ Tries each (dispatch_condition, impl) in the codebook quantized linear dispatch table. Raises if no specialized path is found. @@ -296,9 +300,11 @@ def _codebook_linear_op(input_tensor: torch.Tensor, "No specialized codebook dispatch found for quantized linear op." ) + # Attach the _codebook_linear_op to the CodebookQuantizedTensor class CodebookQuantizedTensor._codebook_linear_op = _codebook_linear_op + def adapt_codebook_1x16(cqt): """ Given a CodebookQuantizedTensor `cqt` with block_size=(1, 16), @@ -309,50 +315,55 @@ def adapt_codebook_1x16(cqt): """ # We expect codebook.shape == [codebook_size, 1, 16]. # AQLM requires shape [num_codebooks=1, codebook_size, out_group_size=1, in_group_size=16]. - codebooks_aqlm = cqt.codebook.unsqueeze(0) #.contiguous() + codebooks_aqlm = cqt.codebook.unsqueeze(0) # .contiguous() # AQLM expects codes.shape == [num_out_groups, num_in_groups, num_codebooks]. # `cqt.codes` is [out_groups, in_groups], we just add the last dim: - codes_aqlm = cqt.codes.unsqueeze(-1) #.contiguous() + codes_aqlm = cqt.codes.unsqueeze(-1) # .contiguous() # AQLM expects scales.shape == [num_out_groups, 1, 1, 1]. # `cqt.scales` is [num_out_groups, num_scale_groups=1, 1] do: - scales_aqlm = cqt.scales.unsqueeze(-1) #.contiguous() + scales_aqlm = cqt.scales.unsqueeze(-1) # .contiguous() return codebooks_aqlm, codes_aqlm, scales_aqlm + def _linear_aqlm_code1x16_check( - input_tensor: torch.Tensor, - weight_tensor: torch.Tensor, - bias: torch.Tensor + input_tensor: torch.Tensor, weight_tensor: torch.Tensor, bias: torch.Tensor ) -> bool: - # don't need adapt_codebook_1x16 and other reshaping if refactored to follow AQLM data representation codebook_size, out_group_size, in_group_size = weight_tensor.codebook.shape - num_codebooks = 1 # right now this is hardcoded, won't be if supporting AQLM + num_codebooks = 1 # right now this is hardcoded, won't be if supporting AQLM return ( isinstance(weight_tensor, CodebookQuantizedTensor) - and (weight_tensor.codebook.device.type, num_codebooks, codebook_size, out_group_size) == ( + and ( + weight_tensor.codebook.device.type, + num_codebooks, + codebook_size, + out_group_size, + ) + == ( "cuda", 1, 65536, 1, - ) + ) and in_group_size in [8, 16] ) + def _linear_aqlm_code1x16_impl( input_tensor: torch.Tensor, weight_tensor: CodebookQuantizedTensor, - bias: torch.Tensor + bias: torch.Tensor, ) -> torch.Tensor: """ Codebook linear implementation using the `code1x16_matmat_dequant` op. 
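    Before the op call, `adapt_codebook_1x16` reshapes the tensor-subclass fields into the
    AQLM layout (shapes only, following the comments in that helper; `g` is the in_group_size,
    8 or 16)::

        cqt.codebook: (2**16, 1, g)   -> codebooks: (1, 2**16, 1, g)
        cqt.codes:    (out, in // g)  -> codes:     (out, in // g, 1)
        cqt.scales:   (out, 1, 1)     -> scales:    (out, 1, 1, 1)

    This path is only taken when `_linear_aqlm_code1x16_check` passes, i.e. the codebook lives
    on CUDA, has 2**16 entries, out_group_size == 1, and in_group_size is 8 or 16; otherwise
    the caller falls back to dequantize + `torch.nn.functional.linear`.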
""" codebook, codes, scales = adapt_codebook_1x16(weight_tensor) - from torchao.ops import code1x16_matmat, code1x16_matmat_dequant + from torchao.ops import code1x16_matmat return code1x16_matmat( input=input_tensor, @@ -362,14 +373,15 @@ def _linear_aqlm_code1x16_impl( bias=bias, ) + register_codebook_quantized_linear_dispatch( - _linear_aqlm_code1x16_check, - _linear_aqlm_code1x16_impl + _linear_aqlm_code1x16_check, _linear_aqlm_code1x16_impl ) implements = CodebookQuantizedTensor.implements + @implements([torch.nn.functional.linear, aten.linear.default]) def _(func, types, args, kwargs): input_tensor, weight_tensor, bias = ( @@ -392,4 +404,4 @@ def _(func, types, args, kwargs): input_tensor = input_tensor.dequantize() if isinstance(weight_tensor, CodebookQuantizedTensor): weight_tensor = weight_tensor.dequantize() - return torch.nn.functional.linear(input_tensor, weight_tensor, bias) \ No newline at end of file + return torch.nn.functional.linear(input_tensor, weight_tensor, bias)