|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +""" |
| 8 | +Triton kernels for scaling high precision tensors to float8. |
| 9 | +""" |
| 10 | + |
| 11 | +from typing import Tuple |
| 12 | + |
| 13 | +import torch |
| 14 | +import triton |
| 15 | +import triton.language as tl |
| 16 | + |
| 17 | +from torchao.prototype.scaled_grouped_mm.utils import _is_column_major |
| 18 | + |
| 19 | +EPS = 1e-12 |
| 20 | + |
| 21 | +FP8_DTYPE_MAP = { |
| 22 | + torch.int8: tl.int8, |
| 23 | + torch.int16: tl.int16, |
| 24 | + torch.int32: tl.int32, |
| 25 | + torch.int64: tl.int64, |
| 26 | + torch.float8_e4m3fn: tl.float8e4nv, |
| 27 | + torch.float8_e5m2: tl.float8e5, |
| 28 | + torch.float16: tl.float16, |
| 29 | + torch.bfloat16: tl.bfloat16, |
| 30 | + torch.float32: tl.float32, |
| 31 | + torch.float64: tl.float64, |
| 32 | +} |
| 33 | + |
| 34 | +kernel_configs_2D = [ |
| 35 | + triton.Config({"BLOCK_SIZE_ROWS": 32, "BLOCK_SIZE_COLS": 32}, num_warps=1), |
| 36 | + triton.Config({"BLOCK_SIZE_ROWS": 64, "BLOCK_SIZE_COLS": 64}, num_warps=8), |
| 37 | + triton.Config({"BLOCK_SIZE_ROWS": 128, "BLOCK_SIZE_COLS": 128}, num_warps=4), |
| 38 | +] |
| 39 | + |
| 40 | + |
| 41 | +def triton_fp8_row_major_jagged_rowwise_scales( |
| 42 | + hp_tensor: torch.Tensor, |
| 43 | + offsets: torch.Tensor, |
| 44 | + output_dtype: torch.dtype = torch.float8_e4m3fn, |
| 45 | + round_scales_to_power_of_2: bool = False, |
| 46 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 47 | + """ |
| 48 | + Converts a high precision tensor to a float8 tensor is row-major memory layout, |
| 49 | + using 'jagged' rowwise scales (i.e., separate scales for each group/subtensor as |
| 50 | + determined by the offsets). |
| 51 | +
|
| 52 | + Args: |
| 53 | + - hp_tensor: 2D high precision tensor to be converted |
| 54 | + - fp8_dtype: desired fp8 dtype |
| 55 | + - offsets: end index for each group/subtensor along dim 1 |
| 56 | + Returns: |
| 57 | + - float8 tensor |
| 58 | + - jagged rowwise scales (i.e., rowwise scales for each group) |
| 59 | + """ |
| 60 | + assert hp_tensor.ndim == 2, "input tensor must be 2D" |
| 61 | + assert hp_tensor.is_contiguous(), "input tensor must be contiguous" |
| 62 | + |
| 63 | + num_elements = hp_tensor.numel() |
| 64 | + tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype] |
| 65 | + tl_output_dtype = FP8_DTYPE_MAP[output_dtype] |
| 66 | + |
| 67 | + fp8_dtype_min = torch.finfo(output_dtype).min |
| 68 | + fp8_dtype_max = torch.finfo(output_dtype).max |
| 69 | + |
| 70 | + m, k = hp_tensor.shape |
| 71 | + n_groups = offsets.numel() |
| 72 | + |
| 73 | + # perform fp8 conversion |
| 74 | + output_buffer = torch.empty_like( |
| 75 | + hp_tensor, dtype=output_dtype, device=hp_tensor.device |
| 76 | + ) |
| 77 | + scales_buffer = torch.empty( |
| 78 | + (m * n_groups), dtype=torch.float32, device=hp_tensor.device |
| 79 | + ) |
| 80 | + |
| 81 | + # parallelize across rows and groups (offsets) |
| 82 | + grid = lambda meta: ( |
| 83 | + triton.cdiv(m, meta["BLOCK_SIZE_ROWS"]), |
| 84 | + offsets.numel(), |
| 85 | + ) |
| 86 | + _triton_fp8_row_major_jagged_rowwise_scales[grid]( |
| 87 | + hp_tensor, |
| 88 | + offsets, |
| 89 | + output_buffer, |
| 90 | + scales_buffer, |
| 91 | + m, |
| 92 | + k, |
| 93 | + hp_tensor.stride(0), |
| 94 | + hp_tensor.stride(1), |
| 95 | + output_buffer.stride(0), |
| 96 | + output_buffer.stride(1), |
| 97 | + num_elements, |
| 98 | + fp8_dtype_min, |
| 99 | + fp8_dtype_max, |
| 100 | + tl_input_dtype, |
| 101 | + tl_output_dtype, |
| 102 | + round_scales_to_power_of_2, |
| 103 | + EPS=EPS, |
| 104 | + ) |
| 105 | + return output_buffer, scales_buffer |
| 106 | + |
| 107 | + |
| 108 | +@triton.autotune(configs=kernel_configs_2D, key=["num_elements"]) |
| 109 | +@triton.jit |
| 110 | +def _triton_fp8_row_major_jagged_rowwise_scales( |
| 111 | + input_ptr, |
| 112 | + offsets_ptr, |
| 113 | + out_ptr, |
| 114 | + scales_ptr, |
| 115 | + M: int, |
| 116 | + K: int, |
| 117 | + stride_input_row: int, |
| 118 | + stride_input_col: int, |
| 119 | + stride_output_row: int, |
| 120 | + stride_output_col: int, |
| 121 | + num_elements: int, |
| 122 | + fp8_dtype_min: tl.constexpr, |
| 123 | + fp8_dtype_max: tl.constexpr, |
| 124 | + input_dtype: tl.constexpr, |
| 125 | + output_dtype: tl.constexpr, |
| 126 | + round_scales_to_power_of_2: tl.constexpr, |
| 127 | + BLOCK_SIZE_ROWS: tl.constexpr, |
| 128 | + BLOCK_SIZE_COLS: tl.constexpr, |
| 129 | + EPS: tl.constexpr, |
| 130 | +): |
| 131 | + # parallel across rows and groups (offsets) |
| 132 | + block_row_id = tl.program_id(axis=0) |
| 133 | + offset_idx = tl.program_id(axis=1) |
| 134 | + |
| 135 | + # determine start and end column idx for this group |
| 136 | + block_row_offs = block_row_id * BLOCK_SIZE_ROWS + tl.arange(0, BLOCK_SIZE_ROWS) |
| 137 | + group_col_start_idx = tl.load( |
| 138 | + offsets_ptr + offset_idx - 1, mask=offset_idx > 0, other=0 |
| 139 | + ) |
| 140 | + group_col_end_idx = tl.load(offsets_ptr + offset_idx) |
| 141 | + |
| 142 | + # compute rowwise amaxes for this group |
| 143 | + amax_buffer = tl.zeros((BLOCK_SIZE_ROWS,), dtype=tl.float64) |
| 144 | + for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_COLS): |
| 145 | + block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_COLS) |
| 146 | + block_offs = ( |
| 147 | + block_row_offs[:, None] * stride_input_row |
| 148 | + + block_col_offs[None, :] * stride_input_col |
| 149 | + ) |
| 150 | + block_mask = (block_row_offs[:, None] < M) & ( |
| 151 | + block_col_offs[None, :] < group_col_end_idx |
| 152 | + ) |
| 153 | + data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to( |
| 154 | + input_dtype |
| 155 | + ) |
| 156 | + amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=1)) |
| 157 | + |
| 158 | + # compute rowwise scales for this group. round scales to nearest power of 2. |
| 159 | + scales = (fp8_dtype_max / tl.clamp(amax_buffer, min=EPS, max=float("inf"))).to( |
| 160 | + tl.float32 |
| 161 | + ) |
| 162 | + if round_scales_to_power_of_2: |
| 163 | + scales = tl.exp2(tl.floor(tl.log2(scales))) |
| 164 | + |
| 165 | + # store rowwise scales for each group in contiguous memory: |
| 166 | + # [group0_row0, group_0_row1, ..., group2_row0, group2_row1] |
| 167 | + scales_offs = block_row_offs + (M * offset_idx) |
| 168 | + scales_mask = tl.arange(0, BLOCK_SIZE_ROWS) < M |
| 169 | + tl.store(scales_ptr + scales_offs, scales, mask=scales_mask) |
| 170 | + |
| 171 | + # perform float8 conversion for this group |
| 172 | + for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_COLS): |
| 173 | + block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_COLS) |
| 174 | + block_offs = ( |
| 175 | + block_row_offs[:, None] * stride_input_row |
| 176 | + + block_col_offs[None, :] * stride_input_col |
| 177 | + ) |
| 178 | + block_mask = (block_row_offs[:, None] < M) & (block_col_offs[None, :] < K) |
| 179 | + data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to( |
| 180 | + input_dtype |
| 181 | + ) |
| 182 | + scaled_data = data * scales[:, None] |
| 183 | + fp8_data = tl.clamp(scaled_data, min=fp8_dtype_min, max=fp8_dtype_max).to( |
| 184 | + output_dtype |
| 185 | + ) |
| 186 | + out_offs = ( |
| 187 | + block_row_offs[:, None] * stride_output_row |
| 188 | + + block_col_offs[None, :] * stride_output_col |
| 189 | + ) |
| 190 | + out_mask = (block_row_offs[:, None] < M) & (block_col_offs[None, :] < K) |
| 191 | + tl.store(out_ptr + out_offs, fp8_data, mask=out_mask) |
| 192 | + |
| 193 | + |
| 194 | +def triton_fp8_col_major_jagged_colwise_scales( |
| 195 | + hp_tensor: torch.Tensor, |
| 196 | + offsets: torch.Tensor, |
| 197 | + output_dtype: torch.dtype = torch.float8_e4m3fn, |
| 198 | + round_scales_to_power_of_2: bool = False, |
| 199 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 200 | + """ |
| 201 | + Converts a high precision tensor to a float8 tensor is row-major memory layout, |
| 202 | + using 'jagged' column-wise scales (i.e., separate scales for each group/subtensor as |
| 203 | + determined by the offsets). |
| 204 | +
|
| 205 | + Args: |
| 206 | + - hp_tensor: 2D high precision tensor to be converted |
| 207 | + - fp8_dtype: desired fp8 dtype |
| 208 | + - offsets: end index for each group/subtensor along dim 0 |
| 209 | + Returns: |
| 210 | + - float8 tensor |
| 211 | + - jagged column-wise scales (i.e., column-wise scales for each group) |
| 212 | + """ |
| 213 | + assert hp_tensor.ndim == 2, "input tensor must be 2D" |
| 214 | + assert _is_column_major(hp_tensor), "input tensor must be column-major" |
| 215 | + |
| 216 | + num_elements = hp_tensor.numel() |
| 217 | + tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype] |
| 218 | + tl_output_dtype = FP8_DTYPE_MAP[output_dtype] |
| 219 | + |
| 220 | + fp8_dtype_min = torch.finfo(output_dtype).min |
| 221 | + fp8_dtype_max = torch.finfo(output_dtype).max |
| 222 | + |
| 223 | + k, n = hp_tensor.shape |
| 224 | + n_groups = offsets.numel() |
| 225 | + |
| 226 | + # perform fp8 conversion |
| 227 | + output_buffer = torch.empty_like( |
| 228 | + hp_tensor, dtype=output_dtype, device=hp_tensor.device |
| 229 | + ) |
| 230 | + scales_buffer = torch.empty( |
| 231 | + (n * n_groups), dtype=torch.float32, device=hp_tensor.device |
| 232 | + ) |
| 233 | + |
| 234 | + # parallelize across columns and groups (offsets) |
| 235 | + grid = lambda meta: ( |
| 236 | + triton.cdiv(n, meta["BLOCK_SIZE_COLS"]), |
| 237 | + offsets.numel(), |
| 238 | + ) |
| 239 | + _triton_fp8_col_major_jagged_colwise_scales[grid]( |
| 240 | + hp_tensor, |
| 241 | + offsets, |
| 242 | + output_buffer, |
| 243 | + scales_buffer, |
| 244 | + k, |
| 245 | + n, |
| 246 | + hp_tensor.stride(0), |
| 247 | + hp_tensor.stride(1), |
| 248 | + output_buffer.stride(0), |
| 249 | + output_buffer.stride(1), |
| 250 | + num_elements, |
| 251 | + fp8_dtype_min, |
| 252 | + fp8_dtype_max, |
| 253 | + tl_input_dtype, |
| 254 | + tl_output_dtype, |
| 255 | + round_scales_to_power_of_2, |
| 256 | + EPS=EPS, |
| 257 | + ) |
| 258 | + return output_buffer, scales_buffer |
| 259 | + |
| 260 | + |
| 261 | +@triton.autotune(configs=kernel_configs_2D, key=["num_elements"]) |
| 262 | +@triton.jit |
| 263 | +def _triton_fp8_col_major_jagged_colwise_scales( |
| 264 | + input_ptr, |
| 265 | + offsets_ptr, |
| 266 | + out_ptr, |
| 267 | + scales_ptr, |
| 268 | + K: int, |
| 269 | + N: int, |
| 270 | + stride_input_row: int, |
| 271 | + stride_input_col: int, |
| 272 | + stride_output_row: int, |
| 273 | + stride_output_col: int, |
| 274 | + num_elements: int, |
| 275 | + fp8_dtype_min: tl.constexpr, |
| 276 | + fp8_dtype_max: tl.constexpr, |
| 277 | + input_dtype: tl.constexpr, |
| 278 | + output_dtype: tl.constexpr, |
| 279 | + round_scales_to_power_of_2: tl.constexpr, |
| 280 | + BLOCK_SIZE_ROWS: tl.constexpr, |
| 281 | + BLOCK_SIZE_COLS: tl.constexpr, |
| 282 | + EPS: tl.constexpr, |
| 283 | +): |
| 284 | + # parallel across columns and groups (offsets) |
| 285 | + block_col_id = tl.program_id(axis=0) |
| 286 | + offset_idx = tl.program_id(axis=1) |
| 287 | + |
| 288 | + # determine start and end row idx for this group |
| 289 | + block_col_offs = block_col_id * BLOCK_SIZE_COLS + tl.arange(0, BLOCK_SIZE_COLS) |
| 290 | + group_row_start_idx = tl.load( |
| 291 | + offsets_ptr + offset_idx - 1, mask=offset_idx > 0, other=0 |
| 292 | + ) |
| 293 | + group_row_end_idx = tl.load(offsets_ptr + offset_idx) |
| 294 | + |
| 295 | + # compute colwise amaxes for this group |
| 296 | + amax_buffer = tl.zeros((BLOCK_SIZE_COLS,), dtype=tl.float64) |
| 297 | + for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ROWS): |
| 298 | + block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ROWS) |
| 299 | + block_offs = ( |
| 300 | + block_row_offs[:, None] * stride_input_row |
| 301 | + + block_col_offs[None, :] * stride_input_col |
| 302 | + ) |
| 303 | + block_mask = (block_row_offs[:, None] < group_row_end_idx) & ( |
| 304 | + block_col_offs[None, :] < N |
| 305 | + ) |
| 306 | + data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to( |
| 307 | + input_dtype |
| 308 | + ) |
| 309 | + amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=0)) |
| 310 | + |
| 311 | + # compute rowwise scales for this group. |
| 312 | + scales = (fp8_dtype_max / tl.clamp(amax_buffer, min=EPS, max=float("inf"))).to( |
| 313 | + tl.float32 |
| 314 | + ) |
| 315 | + if round_scales_to_power_of_2: |
| 316 | + scales = tl.exp2(tl.floor(tl.log2(scales))) |
| 317 | + |
| 318 | + # store colwise scales for each group in contiguous memory: |
| 319 | + # [group0_col0, group_0_col1, ..., group2_col0, group2_col1] |
| 320 | + # note: input tensor is in col-major memory layout. |
| 321 | + scales_offs = block_col_offs + (N * offset_idx) |
| 322 | + scales_mask = tl.arange(0, BLOCK_SIZE_COLS) < N |
| 323 | + tl.store(scales_ptr + scales_offs, scales, mask=scales_mask) |
| 324 | + |
| 325 | + # perform float8 conversion for this group |
| 326 | + for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ROWS): |
| 327 | + block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ROWS) |
| 328 | + block_offs = ( |
| 329 | + block_row_offs[:, None] * stride_input_row |
| 330 | + + block_col_offs[None, :] * stride_input_col |
| 331 | + ) |
| 332 | + block_mask = (block_row_offs[:, None] < group_row_end_idx) & ( |
| 333 | + block_col_offs[None, :] < N |
| 334 | + ) |
| 335 | + data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to( |
| 336 | + input_dtype |
| 337 | + ) |
| 338 | + scaled_data = data * scales[None, :] |
| 339 | + fp8_data = tl.clamp(scaled_data, min=fp8_dtype_min, max=fp8_dtype_max).to( |
| 340 | + output_dtype |
| 341 | + ) |
| 342 | + out_offs = ( |
| 343 | + block_row_offs[:, None] * stride_output_row |
| 344 | + + block_col_offs[None, :] * stride_output_col |
| 345 | + ) |
| 346 | + out_mask = (block_row_offs[:, None] < K) & (block_col_offs[None, :] < N) |
| 347 | + tl.store(out_ptr + out_offs, fp8_data, mask=out_mask) |
0 commit comments