diff --git a/benchmarks/linear/benchmark_grouped_linear.py b/benchmarks/linear/benchmark_grouped_linear.py index 0dbee212d6..44f1c89673 100644 --- a/benchmarks/linear/benchmark_grouped_linear.py +++ b/benchmarks/linear/benchmark_grouped_linear.py @@ -247,7 +247,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4): num_gemms_list = [8] if args.profile: - mkns = [(4096, 4096, 4096)] + mkns = [(4096 * 8, 4096, 4096)] # in profile mode, only run one recipe specified in args.recipe assert args.recipe != "all", ( "In profile mode, only one recipe can be specified, please specify the recipe as" diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 619bf6ca00..9831bbb24d 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -138,6 +138,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, const uint32_t stride_elems, const uint32_t offset_elems, const size_t type_num_bits) { + cuda_driver::ensure_context_exists(); // Get a function pointer to the cuTensorMapEncodeTiled driver API // Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13 static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() { diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h index de5a11eb73..079feb4a7d 100644 --- a/transformer_engine/common/include/transformer_engine/swizzle.h +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -30,6 +30,20 @@ extern "C" { */ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Swizzling scaling factors into the required interleaved layout for GEMM + * + * \param[in] inputs Input tensors with non-swizzled scale_inv. + * \param[in,out] outputs Output tensors which hosts swizzled scale_inv. + * \param[in] stream CUDA stream used for the operation. + * + * Requirements: + * - scale_inv is stored in row-major. + * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale. + * - data is quantitized along K-dimension, i.e. 1D-scaling block lies along the K-dimension. 
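+ *
+ * Call sketch (assuming `inputs`/`outputs` are arrays of `num_tensors` NVTETensor
+ * handles whose scale_inv buffers already satisfy the padding requirements above):
+ *
+ *   nvte_multi_tensor_swizzle_scaling_factors(inputs, outputs, num_tensors, stream);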
+ */ +void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, + const size_t num_tensors, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index cea0e5080b..37d7491d96 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -15,15 +15,17 @@ #include "../util/logging.h" #include "transformer_engine/transformer_engine.h" +namespace transformer_engine { namespace { -constexpr int TB_DIM = 32; -constexpr int NEW_SF_TILE_DIM_K = 16; -constexpr int N_SF_PER_TD_PER_TILE = 4; +constexpr __device__ __host__ int MXFP8_BLOCK_SIZE = 32; +constexpr __device__ __host__ int TB_DIM = 32; +constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16; +constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4; // output is in ~K-major interleaved blocks -constexpr int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4; -constexpr int NEW_SF_TILE_DIM_M_I32 = 32; +constexpr __device__ __host__ int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4; +constexpr __device__ __host__ int NEW_SF_TILE_DIM_M_I32 = 32; template __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) { @@ -51,8 +53,11 @@ __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) { } template -__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M, - const int K) { +__device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K, const int bid_x, + const int bid_y, const int grid_dim_x, + const int grid_dim_y) { constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE; constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; @@ -66,21 +71,24 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons int m_tiles_in_tb = N_TILE_PER_TD; int k_tiles_in_tb = TB_DIM; - if (blockIdx.x == gridDim.x - 1) { + if (bid_x == grid_dim_x - 1) { k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1; } - if (blockIdx.y == gridDim.y - 1) { + if (bid_y == grid_dim_y - 1) { m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1; } - const int32_t* input_i32 = reinterpret_cast(input) + - blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + - blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32; + bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M); + bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K); + + const int input_offset = + bid_x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + bid_y * N_TILE_PER_TD * SF_TILE_DIM_M_I32; + const int32_t* input_i32 = reinterpret_cast(input) + input_offset; int32_t* output_i32[N_TILE_PER_TD]; #pragma unroll for (int i = 0; i < m_tiles_in_tb; i++) { - output_i32[i] = reinterpret_cast(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 + - (blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32; + output_i32[i] = reinterpret_cast(output) + bid_x * TB_DIM * SF_TILE_SIZE_I32 + + (bid_y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32; } extern __shared__ int slm[]; @@ -90,8 +98,18 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons threadIdx.y < k_tiles_in_tb) { #pragma unroll for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { - regs_vec[i] = __ldg(reinterpret_cast( - input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) 
* M_i32 + threadIdx.x * N_TILE_PER_TD)); + const int thread_offset = + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD; + regs_vec[i] = __ldg(reinterpret_cast(input_i32 + thread_offset)); + // Pad zeros + if (padding_m || padding_k) { + for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { + const int index = (input_offset + thread_offset) * sizeof(int) + j; + if (index / M >= original_K || index % M >= original_M) { + reinterpret_cast(regs_vec + i)[j] = 0; + } + } + } } // local shuffle @@ -126,6 +144,14 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons } } +template +__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K) { + swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); +} + template __device__ inline void regs_shuffle(LType* regs_vec) { constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); @@ -143,8 +169,11 @@ __device__ inline void regs_shuffle(LType* regs_vec) { } template -__global__ void swizzle_row_scaling_kernel(const void* input, void* output, const int M, - const int K) { +__device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K, const int bid_x, + const int bid_y, const int grid_dim_x, + const int grid_dim_y) { constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD; @@ -154,14 +183,17 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons int n_tiles_in_tb = N_TILES_IN_TB; const int K_i32 = K / 4; - if (blockIdx.x == gridDim.x - 1) { + if (bid_x == grid_dim_x - 1) { n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1; } - const int* input_i32 = reinterpret_cast(input) + - blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + blockIdx.x * N_TILES_IN_TB; - int* output_i32 = reinterpret_cast(output) + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + - blockIdx.x * N_TILES_IN_TB * SF_TILE_SIZE_I32; + bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M); + bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K); + + const int input_offset = bid_y * SF_TILE_DIM_M_I32 * K_i32 + bid_x * N_TILES_IN_TB; + const int* input_i32 = reinterpret_cast(input) + input_offset; + int* output_i32 = reinterpret_cast(output) + bid_y * SF_TILE_DIM_M_I32 * K_i32 + + bid_x * N_TILES_IN_TB * SF_TILE_SIZE_I32; extern __shared__ int4 slm_v4i[]; @@ -170,8 +202,17 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) { #pragma unroll for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { - regs_vec[i] = __ldg(reinterpret_cast( - input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD)); + const int thread_offset = (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD; + regs_vec[i] = __ldg(reinterpret_cast(input_i32 + thread_offset)); + if (padding_m || padding_k) { + // Pad zeros + for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { + const int index = (input_offset + thread_offset) * sizeof(int) + j; + if (index / K >= original_M || index % K >= original_K) { + reinterpret_cast(regs_vec + i)[j] = 0; + } + } + } } // shuffle regs @@ -196,9 +237,99 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons } } -} // namespace +template +__global__ void 
swizzle_row_scaling_kernel(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K) { + swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); +} -namespace transformer_engine { +constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB +struct MultiSwizzleArgs { + // (input) Data buffers for input scaling factors + void* input_list[kMaxTensorsPerKernel]; + // (output) Data buffers for swizzled scaling factors + void* output_list[kMaxTensorsPerKernel]; + // Input scaling factor m + int m_list[kMaxTensorsPerKernel]; + // Input scaling factor k + int k_list[kMaxTensorsPerKernel]; + // Input scaling factor m before padding + int original_m_list[kMaxTensorsPerKernel]; + // Input scaling factor k before padding + int original_k_list[kMaxTensorsPerKernel]; + // Prefix sum (with leading zero) of CUDA blocks needed for each + // tensor + int block_range[kMaxTensorsPerKernel + 1]; + // Number of tensors being processed by kernel + int num_tensors; +}; + +template +__global__ void multi_tensor_swizzle_row_scaling_kernel(MultiSwizzleArgs kernel_args) { + // Find tensor corresponding to block + const int bid = blockIdx.x; + int tensor_id = 0; + while (kernel_args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + // Get args corresponding to block + const void* input = kernel_args.input_list[tensor_id]; + void* output = kernel_args.output_list[tensor_id]; + const int M = kernel_args.m_list[tensor_id]; + const int K = kernel_args.k_list[tensor_id]; + const int original_M = kernel_args.original_m_list[tensor_id]; + const int original_K = kernel_args.original_k_list[tensor_id]; + + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD; + + // Get block index in grid. Emulate 2D grid. + const int num_tiles_k = K / SF_TILE_DIM_K; + const int num_tiles_m = M / SF_TILE_DIM_M; + const int grid_dim_x = DIVUP(num_tiles_k, N_TILES_IN_TB); + const int grid_dim_y = num_tiles_m; + const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y; + const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y; + + swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); +} + +template +__global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_args) { + // Find tensor corresponding to block + const int bid = blockIdx.x; + int tensor_id = 0; + while (kernel_args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + // Get args corresponding to block + const void* input = kernel_args.input_list[tensor_id]; + void* output = kernel_args.output_list[tensor_id]; + const int M = kernel_args.m_list[tensor_id]; + const int K = kernel_args.k_list[tensor_id]; + const int original_M = kernel_args.original_m_list[tensor_id]; + const int original_K = kernel_args.original_k_list[tensor_id]; + + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE; + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + + // Get block index in grid. Emulate 2D grid. 
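+  // The multi-tensor launch is a flat 1D grid; the block index local to this tensor
+  // (bid - block_range[tensor_id]) is split into bid_x = local / grid_dim_y and
+  // bid_y = local % grid_dim_y, mirroring the 2D grid of the single-tensor kernel.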
+ const int num_tiles_k = K / SF_TILE_DIM_K; + const int num_tiles_m = M / SF_TILE_DIM_M; + const int grid_dim_x = DIVUP(num_tiles_k, TB_DIM); + const int grid_dim_y = DIVUP(num_tiles_m, N_TILE_PER_TD); + const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y; + const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y; + + swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); +} + +} // namespace void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) { @@ -252,27 +383,29 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s int n_tiles_in_tb = TB_DIM * vec_load_size; dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + const int original_M = input->flat_first_dim(); + const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE; switch (vec_load_size) { case 4: cudaFuncSetAttribute(swizzle_row_scaling_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); swizzle_row_scaling_kernel - <<>>(input->scale_inv.dptr, - output->scale_inv.dptr, m, k); + <<>>( + input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); break; case 2: cudaFuncSetAttribute(swizzle_row_scaling_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); swizzle_row_scaling_kernel - <<>>(input->scale_inv.dptr, - output->scale_inv.dptr, m, k); + <<>>( + input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); break; case 1: cudaFuncSetAttribute(swizzle_row_scaling_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); swizzle_row_scaling_kernel - <<>>(input->scale_inv.dptr, - output->scale_inv.dptr, m, k); + <<>>( + input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); break; default: NVTE_ERROR("Not valid vec_load_size."); @@ -285,27 +418,32 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s int n_tiles_in_tb = TB_DIM * vec_load_size; dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + const int original_M = input->flat_last_dim(); + const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; switch (vec_load_size) { case 4: cudaFuncSetAttribute(swizzle_col_scaling_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); swizzle_col_scaling_kernel - <<>>( - input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, + k, original_M, original_K); break; case 2: cudaFuncSetAttribute(swizzle_col_scaling_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); swizzle_col_scaling_kernel - <<>>( - input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, + k, original_M, original_K); break; case 1: cudaFuncSetAttribute(swizzle_col_scaling_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); swizzle_col_scaling_kernel - <<>>( - input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, + k, original_M, original_K); break; default: 
NVTE_ERROR("Not valid vec_load_size."); @@ -317,10 +455,212 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s } else { NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans."); } - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) { - printf("CUDA Error: %s\n", cudaGetErrorString(err)); - exit(-1); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +template +void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, + const int vec_load_size, const bool is_rowwise, + cudaStream_t stream) { + int n_tiles_in_tb = TB_DIM * vec_load_size; + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + /* Calculate number of CUDA blocks needed for each tensor. + * We have to do it here because we have to iterate over all tensors in this batch to + * get the minimum vec_load_size. + */ + for (size_t j = 0; j < kernel_args.num_tensors; j++) { + const int m = kernel_args.m_list[j]; + const int k = kernel_args.k_list[j]; + int num_tiles_m = m / SF_TILE_DIM_M; + int num_tiles_k = k / SF_TILE_DIM_K; + if (is_rowwise) { + kernel_args.block_range[j + 1] = + kernel_args.block_range[j] + DIVUP(num_tiles_k, n_tiles_in_tb) * num_tiles_m; + } else { + kernel_args.block_range[j + 1] = + kernel_args.block_range[j] + + DIVUP(num_tiles_k, TB_DIM) * DIVUP(num_tiles_m, vec_load_size); + } + } + // Launch kernel + const int num_blocks = kernel_args.block_range[kernel_args.num_tensors]; + dim3 block_size(TB_DIM, TB_DIM); + if (is_rowwise) { + switch (vec_load_size) { + case 4: + cudaFuncSetAttribute( + multi_tensor_swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + multi_tensor_swizzle_row_scaling_kernel + <<>>(kernel_args); + break; + case 2: + cudaFuncSetAttribute( + multi_tensor_swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + multi_tensor_swizzle_row_scaling_kernel + <<>>(kernel_args); + break; + case 1: + cudaFuncSetAttribute( + multi_tensor_swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + multi_tensor_swizzle_row_scaling_kernel + <<>>(kernel_args); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } else { + switch (vec_load_size) { + case 4: + cudaFuncSetAttribute( + multi_tensor_swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + multi_tensor_swizzle_col_scaling_kernel + <<>>(kernel_args); + break; + case 2: + cudaFuncSetAttribute( + multi_tensor_swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + multi_tensor_swizzle_col_scaling_kernel + <<>>(kernel_args); + break; + case 1: + cudaFuncSetAttribute( + multi_tensor_swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + multi_tensor_swizzle_col_scaling_kernel + <<>>(kernel_args); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } + NVTE_CHECK_CUDA(cudaGetLastError()); +} +void multi_tensor_swizzle_scaling_factors(const std::vector& input, + std::vector& output, cudaStream_t stream) { + auto num_tensors = input.size(); + bool all_has_data = true; + bool all_has_columnwise_data = true; + for (size_t i = 0; i < num_tensors; i++) { + if (!is_fp8_dtype(input[i]->dtype()) || !is_mxfp_scaling(input[i]->scaling_mode)) { + NVTE_ERROR("Not implemented caling mode " + to_string(input[i]->scaling_mode) + "."); + } + // We don't allow empty tensors. 
They should be filtered out before calling this function. + if (input[i]->data.numel() == 0) { + NVTE_ERROR("Tensor input[" + std::to_string(i) + "] is empty."); + } + CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]"); + CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]"); + all_has_data &= input[i]->has_data(); + all_has_columnwise_data &= input[i]->has_columnwise_data(); + } + NVTE_CHECK(all_has_data || all_has_columnwise_data, + "All tensors should have data or columnwise data."); + + constexpr int SF_TILE_DIM_M = 128; + constexpr int SF_TILE_DIM_K = 4; + if (all_has_data) { + MultiSwizzleArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.block_range[0] = 0; + int vec_load_size = 4; + for (size_t i = 0; i < num_tensors; i++) { + //Launch kernel if argument struct is full + if (kernel_args.num_tensors == kMaxTensorsPerKernel) { + // There is no int3 and misaligned if using int4/int2. + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_swizzle_scaling_factors( + kernel_args, vec_load_size, true, stream); + // Reset the argument struct and vec_load_size + kernel_args.num_tensors = 0; + vec_load_size = 4; + } + const int m = input[i]->scale_inv.shape[0]; + const int k = input[i]->scale_inv.shape[1]; + + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + NVTE_CHECK( + m * k == std::accumulate(output[i]->scale_inv.shape.begin(), + output[i]->scale_inv.shape.end(), 1, std::multiplies()), + "Input.scale_inv size is not equal to Output.scale_inv size!"); + + int num_tiles_k = k / SF_TILE_DIM_K; + int vec_load_size_i = (num_tiles_k - 1) % 4 + 1; + // We use the minimum vec_load_size across all tensors. + vec_load_size = std::min(vec_load_size, vec_load_size_i); + + const int pos = kernel_args.num_tensors; + kernel_args.input_list[pos] = const_cast(input[i]->scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->scale_inv.dptr; + kernel_args.m_list[pos] = m; + kernel_args.k_list[pos] = k; + kernel_args.original_m_list[pos] = input[i]->flat_first_dim(); + kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / MXFP8_BLOCK_SIZE; + kernel_args.num_tensors++; + } + // Launch the remaining tensors + // There is no int3 and misaligned if using int4/int2. + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_swizzle_scaling_factors( + kernel_args, vec_load_size, true, stream); + } + + if (all_has_columnwise_data) { + MultiSwizzleArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.block_range[0] = 0; + int vec_load_size = 4; + for (size_t i = 0; i < num_tensors; i++) { + //Launch kernel if argument struct is full + if (kernel_args.num_tensors == kMaxTensorsPerKernel) { + // There is no int3 and misaligned if using int4/int2. 
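+        // (A remainder of 3 ints has no matching vector type: int3 loads do not exist
+        //  and int4/int2 loads would be misaligned, so fall back to plain int loads.)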
+ if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_swizzle_scaling_factors( + kernel_args, vec_load_size, false, stream); + // Reset the argument struct and vec_load_size + kernel_args.num_tensors = 0; + vec_load_size = 4; + } + const int m = input[i]->columnwise_scale_inv.shape[1]; + const int k = input[i]->columnwise_scale_inv.shape[0]; + + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(), + output[i]->columnwise_scale_inv.shape.end(), 1, + std::multiplies()), + "Input.columnwise_scale_inv size is not equal to " + "Output.columnwise_scale_inv size!"); + + int num_tiles_k = k / SF_TILE_DIM_K; + int vec_load_size_i = (num_tiles_k - 1) % 4 + 1; + // We use the minimum vec_load_size across all tensors. + vec_load_size = std::min(vec_load_size, vec_load_size_i); + + const int pos = kernel_args.num_tensors; + kernel_args.input_list[pos] = const_cast(input[i]->columnwise_scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr; + kernel_args.m_list[pos] = m; + kernel_args.k_list[pos] = k; + kernel_args.original_m_list[pos] = input[i]->flat_last_dim(); + kernel_args.original_k_list[pos] = input[i]->flat_first_dim() / MXFP8_BLOCK_SIZE; + kernel_args.num_tensors++; + } + // Launch the remaining tensors + // There is no int3 and misaligned if using int4/int2. + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_swizzle_scaling_factors( + kernel_args, vec_load_size, false, stream); } } } // namespace transformer_engine @@ -335,3 +675,16 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud using namespace transformer_engine; swizzle_scaling_factors(convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream); } + +void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, + const size_t num_tensors, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_swizzle_scaling_factors); + using namespace transformer_engine; + NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0."); + std::vector input_list, output_list; + for (size_t i = 0; i < num_tensors; i++) { + input_list.push_back(convertNVTETensorCheck(inputs[i])); + output_list.push_back(convertNVTETensorCheck(outputs[i])); + } + multi_tensor_swizzle_scaling_factors(input_list, output_list, stream); +} diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index a1899d5b10..ad6cf2a2ee 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -35,6 +35,7 @@ struct MultiPaddingArgs { int padded_num_rows_list[kMaxTensorsPerKernel]; // Input matrix widths int row_length_list[kMaxTensorsPerKernel]; + // Prefix sum (with leading zero) of CUDA blocks needed for each // tensor int block_range[kMaxTensorsPerKernel + 1]; // Number of tensors being processed by kernel diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 5408cf1a6b..fe7aecbc22 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -398,11 +398,8 @@ std::tuple, std::vector> bulk_allocate_mx } // Allocate full buffer - // TODO(zhongbo): use torch.empty if zero 
padding is added to the swizzle kernel auto buffer = std::make_shared( - at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - // auto buffer = std::make_shared( - // at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -441,11 +438,8 @@ std::tuple, std::vector> bulk_allocate_mx } // Allocate full buffer - // TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel auto buffer = std::make_shared( - at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - // auto buffer = std::make_shared( - // at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 99bb4e69fd..4f1ab3e561 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -326,10 +326,8 @@ std::optional> te_general_grouped_gemm( size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) { std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, te_pre_gelu_out_vector, te_workspace_vector; - std::vector wrappers; + std::vector te_A_wrappers, te_B_wrappers, wrappers; std::vector D_vectors; - // Keep the swizzled scaling factor tensors alive during the GEMMs. - std::vector> swizzled_scale_inverses_list; auto none = py::none(); @@ -396,10 +394,6 @@ std::optional> te_general_grouped_gemm( continue; } - // Optionally swizzle the scaling factors - swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_A, transa))); - swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_B, !transb))); - auto te_D = makeTransformerEngineTensor(out_tensor); auto te_bias = makeTransformerEngineTensor(bias[i]); auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]); @@ -419,18 +413,25 @@ std::optional> te_general_grouped_gemm( te_bias_vector.emplace_back(te_bias.data()); te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data()); - wrappers.emplace_back(std::move(te_A)); - wrappers.emplace_back(std::move(te_B)); + te_A_wrappers.emplace_back(std::move(te_A)); + te_B_wrappers.emplace_back(std::move(te_B)); wrappers.emplace_back(std::move(te_D)); wrappers.emplace_back(std::move(te_bias)); wrappers.emplace_back(std::move(te_pre_gelu_out)); } + + // Optionally swizzle the scaling factors + // Keep the swizzled scaling factor tensors alive during the GEMMs. + auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa); + auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb); + for (size_t i = 0; i < workspace.size(); i++) { auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), std::vector{workspaceSize}, DType::kByte); te_workspace_vector.emplace_back(wsp.data()); wrappers.emplace_back(std::move(wsp)); } + // For now, we only have multi-stream cublas backend. 
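+  // `swizzled_scale_inv_A` / `swizzled_scale_inv_B` above own the buffers that the A/B
+  // wrappers' scale_inv pointers now reference, so they must stay alive for these GEMMs.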
NVTE_SCOPED_GIL_RELEASE({ nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index f0e0aba00d..fc5f99dcb9 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -841,13 +841,13 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape.begin(), rowwise_scale_inv_shape.end()); rowwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); - rowwise_scale_inv_tensor = at::zeros(scale_inv_shape_int64, uint8_tensor_opts); + rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts); } if (columnwise_usage) { const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), columnwise_scale_inv_shape.end()); columnwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); - columnwise_scale_inv_tensor = at::zeros(scale_inv_shape_int64, uint8_tensor_opts); + columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts); } // Convert tensors to Python @@ -939,7 +939,7 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), scale_inv_shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - rowwise_scale_inv = at::zeros(scale_inv_shape_int64, opts); + rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts); tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; } } else { // rowwise_usage == false @@ -966,7 +966,7 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), scale_inv_shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - columnwise_scale_inv = at::zeros(scale_inv_shape_int64, opts); + columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts); tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; } } else { // columnwise_usage == false diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index a878345ffc..92f2d3a500 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -75,3 +75,98 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap return swizzled_scale_inv; } + +std::optional multi_tensor_swizzle_scaling_factors( + std::vector& tensors, bool rowwise) { + using namespace transformer_engine::pytorch; + + if (tensors.empty()) { + return std::nullopt; + } + + bool all_same_scaling_mode = std::all_of( + tensors.cbegin(), tensors.cend(), [&tensors](const transformer_engine::TensorWrapper& val) { + return val.scaling_mode() == tensors.front().scaling_mode(); + }); + NVTE_CHECK(all_same_scaling_mode, "Scaling mode of the input tensors must be the same."); + + if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) { + NVTE_ERROR("Invalid scaling mode for swizzle."); + } else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING) { + return std::nullopt; + } + + std::vector wrappers; + std::vector input_tensors, output_tensors; + + // Collect scale_inv shapes and calculate buffer size and offsets for scale_invs + std::vector> scale_inv_shapes; + std::vector scale_inv_dptrs; + size_t buffer_size = 0; + std::vector scale_inv_offsets; + constexpr size_t scale_elem_size = 1; + for (auto& tensor : tensors) { + NVTEBasicTensor scale_inv; + 
if (rowwise) { + scale_inv = tensor.get_rowwise_scale_inv(); + } else { + scale_inv = tensor.get_columnwise_scale_inv(); + } + auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape); + buffer_size = roundup(buffer_size, 16); // align to 16B + scale_inv_offsets.push_back(buffer_size); + buffer_size += product(scale_inv_shape) * scale_elem_size; + scale_inv_shapes.emplace_back(scale_inv_shape); + scale_inv_dptrs.push_back(scale_inv.data_ptr); + } + + // Allocate full buffer + auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); + + for (size_t i = 0; i < tensors.size(); ++i) { + auto& tensor = tensors[i]; + void* scale_inv_dptr = scale_inv_dptrs[i]; + void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]); + auto input_shape = nvte_shape_to_vector(tensor.shape()); + + // Reconstruct input only to avoid swizzling both directions if not needed. + // Use any 8 bit type, it's irrelevant. + transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); + transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + if (rowwise) { + input_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); + input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, + scale_inv_shapes[i]); + output_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, + input_shape); + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, + transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + // Set the swizzled scaling factor to the original tensor. + tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, + scale_inv_shapes[i]); + } else { + input_cu.set_columnwise_data(tensor.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, + input_shape); + input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, + scale_inv_shapes[i]); + output_cu.set_columnwise_data(tensor.columnwise_dptr(), + transformer_engine::DType::kFloat8E4M3, input_shape); + output_cu.set_columnwise_scale_inv( + swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + // Set the swizzled scaling factor to the original tensor. + tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, + transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + } + + input_tensors.emplace_back(input_cu.data()); + output_tensors.emplace_back(output_cu.data()); + wrappers.emplace_back(std::move(input_cu)); + wrappers.emplace_back(std::move(output_cu)); + } + + // Launch kernel + nvte_multi_tensor_swizzle_scaling_factors(input_tensors.data(), output_tensors.data(), + input_tensors.size(), at::cuda::getCurrentCUDAStream()); + + return buffer; +} diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 0cfeb81f59..4b26860967 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -13,11 +13,18 @@ #include "transformer_engine/transformer_engine.h" -/* Swizzle the scaling factor of the input tensor. +/*! \brief Swizzle the scaling factor of the input tensor. * * The returned swizzled scaling factor tensor should be kept alive during the GEMM. */ std::optional swizzle_scaling_factors(transformer_engine::TensorWrapper &input, - bool trans); + bool rowwise); + +/*! \brief Swizzle the scaling factor of the input tensors. + * + * The returned swizzled scaling factor tensors should be kept alive during the GEMMs. 
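+ *
+ * Sketch of the intended call pattern (mirroring te_general_grouped_gemm; the wrapper
+ * vector and the `rowwise` flag come from the caller):
+ *
+ *   auto swizzled = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa);
+ *   // ... run the grouped GEMMs while `swizzled` is still in scope ...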
+ */ +std::optional multi_tensor_swizzle_scaling_factors( + std::vector &inputs, bool rowwise); #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_