import torch
import triton
import triton.language as tl
from torch import Tensor
from torch.library import triton_op, wrap_triton

from torchao.utils import ceil_div


def to_blocked_per_group_2d(
    x_scales: Tensor, group_offs: Tensor, Mg: int, K: int, block_size: int = 32
) -> tuple[Tensor, Tensor]:
    """
    Convert scales to blocked format for a 2D tensor (input activations / token groups).

    Args:
        x_scales: Tensor with the per-group e8m0 scales (not yet blocked) concatenated into one tensor along the Mg dimension.
        group_offs: Tensor of shape (num_groups,) which contains the end index of each group along the Mg dimension.
        Mg: total size of all groups summed together
        K: K dim size
        block_size: scaling block size along K (only 32 is supported for now)

    Returns:
        blocked_scales: Tensor with each group's scales rearranged into blocked format and concatenated along the row dimension.
        start_row_after_padding: Tensor of shape (num_groups + 1,) which contains the start row after padding for each group (first entry is 0).
    """
    from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import _to_blocked

    assert x_scales.ndim == 2, "x_scales must be 2D"
    assert block_size == 32, "Only block_size=32 is supported for now"
    blocked_scales_list = []
    start_row_after_padding_list = [0]
    group_start_idx = 0
    for i, group_end_idx in enumerate(group_offs.tolist()):
        group_size = group_end_idx - group_start_idx
        prev_start_row_after_padding = start_row_after_padding_list[i]
        if group_size == 0:
            start_row_after_padding_list.append(prev_start_row_after_padding)
            continue

        # Convert group scales to blocked format
        group_scales = x_scales[group_start_idx:group_end_idx]
        group_scales_blocked = _to_blocked(group_scales)
        blocked_scales_list.append(group_scales_blocked)

        # Calculate the start row after padding
        scaling_groups_per_row = K // block_size
        rows_for_group = group_scales_blocked.numel() // scaling_groups_per_row
        new_start_row = prev_start_row_after_padding + rows_for_group
        start_row_after_padding_list.append(new_start_row)

        # Update next group start index
        group_start_idx = group_end_idx

    blocked_scales = torch.cat(blocked_scales_list, dim=0).contiguous()
    blocked_scales = blocked_scales.reshape(-1, K // block_size)
    start_row_after_padding = torch.tensor(
        start_row_after_padding_list, device=x_scales.device, dtype=torch.int64
    )
    return blocked_scales, start_row_after_padding
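

# Hedged usage sketch (not part of the original file): one way to_blocked_per_group_2d
# might be called. Shapes, dtypes and the "cuda" device are illustrative assumptions;
# fbgemm_gpu must be installed for the underlying _to_blocked call.
def _example_to_blocked_per_group_2d() -> None:
    Mg, K, block_size = 256, 128, 32
    # One e8m0 scale (stored as 1 byte) per (row, 32-wide block of K).
    x_scales = torch.randint(
        0, 255, (Mg, K // block_size), dtype=torch.uint8, device="cuda"
    )
    # Two token groups: rows [0, 100) and [100, 256).
    group_offs = torch.tensor([100, 256], device="cuda", dtype=torch.int32)
    blocked_scales, start_rows = to_blocked_per_group_2d(x_scales, group_offs, Mg, K)
    # Each group is padded up to a multiple of 128 rows: 100 -> 128 and 156 -> 256,
    # so blocked_scales has 384 rows and start_rows is [0, 128, 384].
    assert blocked_scales.shape == (384, K // block_size)
    assert start_rows.tolist() == [0, 128, 384]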


def to_blocked_per_group_3d(weight_scales: Tensor) -> Tensor:
    """
    Convert scales to blocked format for each group for a 3D tensor (expert weights).

    Args:
        weight_scales: Tensor of shape (E, N, K//block_size) containing the per-expert scales.

    Returns:
        weight_scales_blocked: Tensor of shape (E, -1) with each expert's scales rearranged into blocked format and flattened.
    """
    from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import _to_blocked

    blocked_scales_list = []
    num_groups = weight_scales.shape[0]
    for i in range(num_groups):
        group_scales = weight_scales[i]
        group_scales_blocked = _to_blocked(group_scales)
        blocked_scales_list.append(group_scales_blocked)
    weight_scales_blocked = torch.stack(blocked_scales_list, dim=0).contiguous()
    weight_scales_blocked = weight_scales_blocked.reshape(num_groups, -1)
    return weight_scales_blocked
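

# Hedged usage sketch (illustrative, not from the original file): per-expert weight scales
# are swizzled independently. Shapes and the "cuda" device are assumptions; fbgemm_gpu is
# required for _to_blocked.
def _example_to_blocked_per_group_3d() -> None:
    E, N, K, block_size = 4, 256, 128, 32
    weight_scales = torch.randint(
        0, 255, (E, N, K // block_size), dtype=torch.uint8, device="cuda"
    )
    blocked = to_blocked_per_group_3d(weight_scales)
    # Each (N, K // block_size) slice is padded to a multiple of (128, 4) and flattened,
    # so with N=256 and K // block_size = 4 the result is (E, 256 * 4).
    assert blocked.shape == (E, 256 * 4)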


def compute_per_group_blocked_scale_offsets(offsets: torch.Tensor):
    """
    Compute per-group sizes and the starting row of each group after each group is
    padded up to the nearest multiple of 128 rows.

    Args:
        offsets: A 1D PyTorch tensor of integers in ascending sorted order, representing the end index of each group along the Mg dimension.

    Returns:
        - group_sizes: A 1D PyTorch tensor of integers representing the size of each group.
        - starting_row_after_padding: 1D integer tensor of shape (num_groups + 1,) where element i is the
          starting row of group i after padding to blocked format; the last element is the total number of padded rows.
    """
    # Calculate group sizes
    zero = torch.tensor([0], dtype=offsets.dtype, device=offsets.device)
    group_sizes = torch.diff(offsets, prepend=zero).to(torch.int64)

    # Round each group size up to the nearest multiple of 128
    rounded_group_sizes = ceil_div(group_sizes, 128) * 128

    # Calculate the starting row after padding for each group
    # (prepend 0 so entry i is the start row of group i)
    starting_row_after_padding = torch.cat(
        [zero.to(torch.int64), torch.cumsum(rounded_group_sizes, dim=0)]
    )
    return group_sizes, starting_row_after_padding
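

# Hedged worked example (illustrative): for group end offsets [40, 40, 300], the group sizes
# are [40, 0, 260]; padding each group to a multiple of 128 rows gives [128, 0, 384], so the
# padded start rows are [0, 128, 128, 512]. Runs on CPU and does not need fbgemm_gpu.
def _example_compute_per_group_blocked_scale_offsets() -> None:
    offsets = torch.tensor([40, 40, 300], dtype=torch.int32)
    group_sizes, start_rows = compute_per_group_blocked_scale_offsets(offsets)
    assert group_sizes.tolist() == [40, 0, 260]
    assert start_rows.tolist() == [0, 128, 128, 512]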


@triton_op("torchao::triton_mx_block_rearrange_per_group", mutates_args=())
def triton_mx_block_rearrange_per_group(
    scales_tensor: torch.Tensor,
    offsets: torch.Tensor,
) -> torch.Tensor:
    """
    Rearranges an E8M0 tensor scale to block-scaled swizzle format.

    This format is suitable for Tmem as described in NVIDIA documentation:
    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        scales_tensor: Input tensor containing e8m0 scales for each logical group of a target tensor.
        offsets: Tensor of shape (num_groups,) which contains the end index of each group along the M dimension.

    Returns:
        Rearranged tensor in block-scaled swizzle format
    """
    assert scales_tensor.element_size() == 1, (
        "Expected element size to be 1 byte (8 bits)"
    )
    _, output_scales_group_offsets = compute_per_group_blocked_scale_offsets(offsets)
    rows, cols = scales_tensor.shape

    # Calculate blocks needed
    num_groups = offsets.numel()
    # Final offset is the total number of padded rows in the output tensor
    padded_rows = output_scales_group_offsets[-1]
    num_col_blocks = ceil_div(cols, 4)
    padded_cols = num_col_blocks * 4
    out = scales_tensor.new_empty((padded_rows, padded_cols))

    # We probably want to handle multiple blocks per tile, but keep it simple for now
    BLOCK_ROWS, BLOCK_COLS = 128, 4

    # Output block stride for the rearranged format
    output_stride_per_row_of_blocks = (
        BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)
    )

    # We parallelize per group and per col block.
    # Rows per group is variable, so we just loop through row blocks per group, per col block.
    grid = lambda META: (
        num_groups,
        num_col_blocks,
    )

    wrap_triton(triton_scale_swizzle_per_group)[grid](
        # Input scales
        scales_tensor.view(torch.uint8),
        scales_tensor.stride(0),
        scales_tensor.stride(1),
        rows,
        cols,
        num_groups,
        # Original offsets (to read from)
        offsets,
        # Output scales tensor and group offsets after padding (to write to)
        out.view(torch.uint8),
        output_scales_group_offsets,
        output_stride_per_row_of_blocks,
        BLOCK_ROWS=BLOCK_ROWS,
        BLOCK_COLS=BLOCK_COLS,
    )
    return out
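

# Hedged end-to-end sketch (illustrative, not from the original file): the Triton op and the
# eager to_blocked_per_group_2d reference are intended to produce the same padded, swizzled
# layout. Shapes, dtypes and the "cuda" device are assumptions; fbgemm_gpu is required for
# the reference path.
def _example_triton_mx_block_rearrange_per_group() -> None:
    Mg, K, block_size = 384, 128, 32
    scales = torch.randint(
        0, 255, (Mg, K // block_size), dtype=torch.uint8, device="cuda"
    )
    offsets = torch.tensor([100, 256, 384], device="cuda", dtype=torch.int32)
    out_triton = triton_mx_block_rearrange_per_group(scales, offsets)
    out_ref, _ = to_blocked_per_group_2d(scales, offsets, Mg, K, block_size)
    # K // block_size is already a multiple of 4 here, so both outputs are (512, 4).
    assert torch.equal(out_triton, out_ref)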


@triton.jit
def triton_scale_swizzle_per_group(
    scales_ptr,  # (M, K // block_size)
    scales_stride_dim0,
    scales_stride_dim1,
    scale_rows,
    scale_cols,
    num_groups,
    orig_offsets,  # (num_groups,)
    output_scales_ptr,  # (padded_rows, cdiv(K // block_size, 4) * 4)
    output_scales_group_offsets,  # (num_groups + 1,)
    output_stride_per_row_of_blocks,
    BLOCK_ROWS: tl.constexpr,
    BLOCK_COLS: tl.constexpr,
):
    group_pid = tl.program_id(0)
    block_col_pid = tl.program_id(1)

    row_offs = tl.arange(0, BLOCK_ROWS)[:, None]
    col_offs = tl.arange(0, BLOCK_COLS)[None, :]

    # Row range for this group in the input scales
    input_start_row = tl.load(orig_offsets + group_pid - 1, mask=group_pid > 0, other=0)
    input_end_row = tl.load(
        orig_offsets + group_pid, mask=group_pid < num_groups, other=0
    )

    # Row range in the output scales tensor we will write to (after per-group padding)
    output_row_start_offset = tl.load(
        output_scales_group_offsets + group_pid, mask=group_pid < num_groups, other=0
    )
    output_row_end_offset = tl.load(
        output_scales_group_offsets + group_pid + 1,
        mask=group_pid < num_groups,
        other=0,
    )

    # Number of columns in the padded output; each padded row holds this many elements
    padded_out_cols = output_stride_per_row_of_blocks // BLOCK_ROWS

    # For this group and col block, we iterate through row blocks, reading a
    # (BLOCK_ROWS, BLOCK_COLS) tile of input scales per iteration.
    # We need to track how many row blocks we have iterated through.
    block_row_id = 0
    row_block_end = (
        input_start_row
        + tl.cdiv(input_end_row - input_start_row, BLOCK_ROWS) * BLOCK_ROWS
    )
    for row_off in tl.range(input_start_row, row_block_end, BLOCK_ROWS):
        # Read block of input scales
        block_row_offs = row_off + row_offs
        block_col_offs = block_col_pid * BLOCK_COLS + col_offs
        block_offs = (
            block_row_offs * scales_stride_dim0 + block_col_offs * scales_stride_dim1
        )
        mask = (block_row_offs < input_end_row) & (block_col_offs < scale_cols)
        input_scales = tl.load(scales_ptr + block_offs, mask=mask, other=0.0)

        # Calculate destination indices for each row and col within the tile:
        # rearrange (128, 4) to (32, 4, 4) then to final (32, 16) coordinates
        r_div_32 = row_offs // 32
        r_mod_32 = row_offs % 32
        dest_indices = r_mod_32 * 16 + r_div_32 * 4 + col_offs

        # Flatten
        dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS))
        scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS))

        # Calculate the output tile offset: the group's padded start row converted to elements,
        # plus the completed rows of blocks, plus this col block's tile within the row of blocks
        output_block_offsets = (
            output_row_start_offset * padded_out_cols
            + (block_row_id * output_stride_per_row_of_blocks)
            + (block_col_pid * BLOCK_ROWS * BLOCK_COLS)
        )

        tl.store(
            output_scales_ptr + output_block_offsets + dest_indices_flat,
            scales_flat,
            mask=(output_block_offsets + dest_indices_flat)
            < output_row_end_offset * padded_out_cols,
        )

        # Update row block id to next block
        block_row_id += 1
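

# Hedged reference sketch (illustrative): the dest_indices formula above maps a (128, 4) tile
# into the (32, 16) block-scaled swizzle layout. The same mapping can be expressed with plain
# PyTorch reshapes, which may help when reasoning about the kernel:
def _reference_swizzle_tile(tile: torch.Tensor) -> torch.Tensor:
    # (128, 4) -> (4, 32, 4) -> (32, 4, 4) -> (32, 16): element (r, c) lands at flat index
    # (r % 32) * 16 + (r // 32) * 4 + c, matching dest_indices in the kernel.
    assert tile.shape == (128, 4)
    return tile.reshape(4, 32, 4).permute(1, 0, 2).reshape(32, 16)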