Skip to content

Commit f4c1c51

Browse files
add triton kernels for float8 quantization with jagged rowwise scales
1 parent f788897 commit f4c1c51

File tree

6 files changed

+470
-31
lines changed

6 files changed

+470
-31
lines changed

torchao/prototype/scaled_grouped_mm/kernels/__init__.py

Whitespace-only changes.
Lines changed: 347 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,347 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Triton kernels for scaling high precision tensors to float8.
9+
"""
10+
11+
from typing import Tuple
12+
13+
import torch
14+
import triton
15+
import triton.language as tl
16+
17+
from torchao.prototype.scaled_grouped_mm.utils import _is_column_major
18+
19+
# Lower bound applied to the amax values before they are used as a divisor when
# computing scales (passed to the kernels as the EPS constexpr).
EPS = 1e-12

# Lookup from torch dtypes to the corresponding triton.language dtypes.
# Despite the name, it covers integer and high precision float dtypes as well
# as the float8 dtypes, since it is used for both kernel inputs and outputs.
FP8_DTYPE_MAP = {
    # integer dtypes
    torch.int8: tl.int8,
    torch.int16: tl.int16,
    torch.int32: tl.int32,
    torch.int64: tl.int64,
    # float8 dtypes
    torch.float8_e4m3fn: tl.float8e4nv,
    torch.float8_e5m2: tl.float8e5,
    # high precision float dtypes
    torch.float16: tl.float16,
    torch.bfloat16: tl.bfloat16,
    torch.float32: tl.float32,
    torch.float64: tl.float64,
}

# Autotuning candidates for the 2D kernels, expressed as
# (BLOCK_SIZE_ROWS, BLOCK_SIZE_COLS, num_warps) triples.
kernel_configs_2D = [
    triton.Config(
        {"BLOCK_SIZE_ROWS": block_rows, "BLOCK_SIZE_COLS": block_cols},
        num_warps=warps,
    )
    for block_rows, block_cols, warps in (
        (32, 32, 1),
        (64, 64, 8),
        (128, 128, 4),
    )
]
39+
40+
41+
def triton_fp8_row_major_jagged_rowwise_scales(
    hp_tensor: torch.Tensor,
    offsets: torch.Tensor,
    output_dtype: torch.dtype = torch.float8_e4m3fn,
    round_scales_to_power_of_2: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Converts a high precision tensor to a float8 tensor in row-major memory layout,
    using 'jagged' rowwise scales (i.e., separate scales for each group/subtensor as
    determined by the offsets).

    Args:
        - hp_tensor: 2D, contiguous high precision tensor of shape (M, K) to be converted
        - offsets: end index for each group/subtensor along dim 1
        - output_dtype: desired fp8 dtype (defaults to torch.float8_e4m3fn)
        - round_scales_to_power_of_2: if True, each scale is rounded down to a power of 2
    Returns:
        - float8 tensor with the same shape as `hp_tensor`
        - jagged rowwise scales (i.e., rowwise scales for each group), as a flat float32
          tensor of M * num_groups elements stored group by group:
          [group0_row0, group0_row1, ..., groupN_row0, groupN_row1, ...]
    """
    assert hp_tensor.ndim == 2, "input tensor must be 2D"
    assert hp_tensor.is_contiguous(), "input tensor must be contiguous"
    # NOTE(review): offsets is presumably an integer tensor on the same device as
    # hp_tensor, since the kernel loads from it directly — confirm with callers.

    num_elements = hp_tensor.numel()
    tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
    tl_output_dtype = FP8_DTYPE_MAP[output_dtype]

    # representable range of the target dtype, used to clamp the scaled values
    fp8_dtype_min = torch.finfo(output_dtype).min
    fp8_dtype_max = torch.finfo(output_dtype).max

    m, k = hp_tensor.shape
    n_groups = offsets.numel()

    # allocate the float8 output and one scale per (group, row) pair
    output_buffer = torch.empty_like(
        hp_tensor, dtype=output_dtype, device=hp_tensor.device
    )
    scales_buffer = torch.empty(
        (m * n_groups), dtype=torch.float32, device=hp_tensor.device
    )

    # parallelize across row blocks and groups (offsets)
    def grid(meta):
        return (triton.cdiv(m, meta["BLOCK_SIZE_ROWS"]), n_groups)

    _triton_fp8_row_major_jagged_rowwise_scales[grid](
        hp_tensor,
        offsets,
        output_buffer,
        scales_buffer,
        m,
        k,
        hp_tensor.stride(0),
        hp_tensor.stride(1),
        output_buffer.stride(0),
        output_buffer.stride(1),
        num_elements,
        fp8_dtype_min,
        fp8_dtype_max,
        tl_input_dtype,
        tl_output_dtype,
        round_scales_to_power_of_2,
        EPS=EPS,
    )
    return output_buffer, scales_buffer
106+
107+
108+
@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
@triton.jit
def _triton_fp8_row_major_jagged_rowwise_scales(
    input_ptr,
    offsets_ptr,
    out_ptr,
    scales_ptr,
    M: int,
    K: int,
    stride_input_row: int,
    stride_input_col: int,
    stride_output_row: int,
    stride_output_col: int,
    num_elements: int,
    fp8_dtype_min: tl.constexpr,
    fp8_dtype_max: tl.constexpr,
    input_dtype: tl.constexpr,
    output_dtype: tl.constexpr,
    round_scales_to_power_of_2: tl.constexpr,
    BLOCK_SIZE_ROWS: tl.constexpr,
    BLOCK_SIZE_COLS: tl.constexpr,
    EPS: tl.constexpr,
):
    """
    Quantizes one (row block, group) tile of the input to float8.

    Each program handles BLOCK_SIZE_ROWS rows (program axis 0) of one group of
    columns (program axis 1). The group spans columns
    [offsets[g - 1], offsets[g]), with an implicit start of 0 for the first group.
    The program first reduces the absolute max of every row over the group's
    columns, derives one scale per row, stores the scales at
    scales_ptr[g * M + row], and finally writes the scaled, clamped values to
    out_ptr.

    K is no longer read by the kernel body (column masks are bounded by the
    group's end offset); it is kept so the launch signature stays unchanged.
    num_elements is only used as the autotune key.
    """
    # parallel across rows and groups (offsets)
    block_row_id = tl.program_id(axis=0)
    offset_idx = tl.program_id(axis=1)

    # rows handled by this program; rows >= M belong to a partial last block
    block_row_offs = block_row_id * BLOCK_SIZE_ROWS + tl.arange(0, BLOCK_SIZE_ROWS)
    row_mask = block_row_offs < M

    # determine start and end column idx for this group
    # (the first group has no previous offset, so its start defaults to 0)
    group_col_start_idx = tl.load(
        offsets_ptr + offset_idx - 1, mask=offset_idx > 0, other=0
    )
    group_col_end_idx = tl.load(offsets_ptr + offset_idx)

    # compute rowwise amaxes for this group
    amax_buffer = tl.zeros((BLOCK_SIZE_ROWS,), dtype=tl.float64)
    for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_COLS):
        block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_COLS)
        block_offs = (
            block_row_offs[:, None] * stride_input_row
            + block_col_offs[None, :] * stride_input_col
        )
        # masked-out elements load as 0.0, which cannot raise an abs-max
        block_mask = row_mask[:, None] & (block_col_offs[None, :] < group_col_end_idx)
        data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to(
            input_dtype
        )
        amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=1))

    # compute rowwise scales for this group; amax is clamped to EPS so an
    # all-zero row does not divide by zero
    scales = (fp8_dtype_max / tl.clamp(amax_buffer, min=EPS, max=float("inf"))).to(
        tl.float32
    )
    # optionally round each scale down to a power of 2
    if round_scales_to_power_of_2:
        scales = tl.exp2(tl.floor(tl.log2(scales)))

    # store rowwise scales for each group in contiguous memory:
    # [group0_row0, group_0_row1, ..., group2_row0, group2_row1]
    # The mask must use the global row index: masking on the local lane index
    # would let a partial last row block write into the next group's scales
    # (or past the end of the buffer for the last group).
    scales_offs = block_row_offs + (M * offset_idx)
    tl.store(scales_ptr + scales_offs, scales, mask=row_mask)

    # perform float8 conversion for this group
    for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_COLS):
        block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_COLS)
        block_offs = (
            block_row_offs[:, None] * stride_input_row
            + block_col_offs[None, :] * stride_input_col
        )
        # Bound columns by the group's end, not K: the last tile of a group may
        # extend into the next group, whose columns must only be written by the
        # program that owns that group (with that group's scales).
        block_mask = row_mask[:, None] & (block_col_offs[None, :] < group_col_end_idx)
        data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to(
            input_dtype
        )
        scaled_data = data * scales[:, None]
        fp8_data = tl.clamp(scaled_data, min=fp8_dtype_min, max=fp8_dtype_max).to(
            output_dtype
        )
        out_offs = (
            block_row_offs[:, None] * stride_output_row
            + block_col_offs[None, :] * stride_output_col
        )
        tl.store(out_ptr + out_offs, fp8_data, mask=block_mask)
192+
193+
194+
def triton_fp8_col_major_jagged_colwise_scales(
    hp_tensor: torch.Tensor,
    offsets: torch.Tensor,
    output_dtype: torch.dtype = torch.float8_e4m3fn,
    round_scales_to_power_of_2: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Converts a high precision, column-major tensor to a float8 tensor,
    using 'jagged' column-wise scales (i.e., separate scales for each group/subtensor as
    determined by the offsets).

    Args:
        - hp_tensor: 2D high precision tensor of shape (K, N) in column-major memory
          layout to be converted
        - offsets: end index for each group/subtensor along dim 0
        - output_dtype: desired fp8 dtype (defaults to torch.float8_e4m3fn)
        - round_scales_to_power_of_2: if True, each scale is rounded down to a power of 2
    Returns:
        - float8 tensor with the same shape as `hp_tensor`
        - jagged column-wise scales (i.e., column-wise scales for each group), as a flat
          float32 tensor of N * num_groups elements stored group by group:
          [group0_col0, group0_col1, ..., groupN_col0, groupN_col1, ...]
    """
    assert hp_tensor.ndim == 2, "input tensor must be 2D"
    assert _is_column_major(hp_tensor), "input tensor must be column-major"
    # NOTE(review): offsets is presumably an integer tensor on the same device as
    # hp_tensor, since the kernel loads from it directly — confirm with callers.

    num_elements = hp_tensor.numel()
    tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
    tl_output_dtype = FP8_DTYPE_MAP[output_dtype]

    # representable range of the target dtype, used to clamp the scaled values
    fp8_dtype_min = torch.finfo(output_dtype).min
    fp8_dtype_max = torch.finfo(output_dtype).max

    k, n = hp_tensor.shape
    n_groups = offsets.numel()

    # allocate the float8 output and one scale per (group, column) pair
    # NOTE(review): empty_like is expected to keep the input's column-major
    # strides for the output — confirm if the output layout matters downstream.
    output_buffer = torch.empty_like(
        hp_tensor, dtype=output_dtype, device=hp_tensor.device
    )
    scales_buffer = torch.empty(
        (n * n_groups), dtype=torch.float32, device=hp_tensor.device
    )

    # parallelize across column blocks and groups (offsets)
    def grid(meta):
        return (triton.cdiv(n, meta["BLOCK_SIZE_COLS"]), n_groups)

    _triton_fp8_col_major_jagged_colwise_scales[grid](
        hp_tensor,
        offsets,
        output_buffer,
        scales_buffer,
        k,
        n,
        hp_tensor.stride(0),
        hp_tensor.stride(1),
        output_buffer.stride(0),
        output_buffer.stride(1),
        num_elements,
        fp8_dtype_min,
        fp8_dtype_max,
        tl_input_dtype,
        tl_output_dtype,
        round_scales_to_power_of_2,
        EPS=EPS,
    )
    return output_buffer, scales_buffer
259+
260+
261+
@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
@triton.jit
def _triton_fp8_col_major_jagged_colwise_scales(
    input_ptr,
    offsets_ptr,
    out_ptr,
    scales_ptr,
    K: int,
    N: int,
    stride_input_row: int,
    stride_input_col: int,
    stride_output_row: int,
    stride_output_col: int,
    num_elements: int,
    fp8_dtype_min: tl.constexpr,
    fp8_dtype_max: tl.constexpr,
    input_dtype: tl.constexpr,
    output_dtype: tl.constexpr,
    round_scales_to_power_of_2: tl.constexpr,
    BLOCK_SIZE_ROWS: tl.constexpr,
    BLOCK_SIZE_COLS: tl.constexpr,
    EPS: tl.constexpr,
):
    """
    Quantizes one (column block, group) tile of the input to float8.

    Each program handles BLOCK_SIZE_COLS columns (program axis 0) of one group
    of rows (program axis 1). The group spans rows [offsets[g - 1], offsets[g]),
    with an implicit start of 0 for the first group. The program first reduces
    the absolute max of every column over the group's rows, derives one scale
    per column, stores the scales at scales_ptr[g * N + col], and finally writes
    the scaled, clamped values to out_ptr.

    K is no longer read by the kernel body (row masks are bounded by the
    group's end offset); it is kept so the launch signature stays unchanged.
    num_elements is only used as the autotune key.
    """
    # parallel across columns and groups (offsets)
    block_col_id = tl.program_id(axis=0)
    offset_idx = tl.program_id(axis=1)

    # columns handled by this program; columns >= N belong to a partial last block
    block_col_offs = block_col_id * BLOCK_SIZE_COLS + tl.arange(0, BLOCK_SIZE_COLS)
    col_mask = block_col_offs < N

    # determine start and end row idx for this group
    # (the first group has no previous offset, so its start defaults to 0)
    group_row_start_idx = tl.load(
        offsets_ptr + offset_idx - 1, mask=offset_idx > 0, other=0
    )
    group_row_end_idx = tl.load(offsets_ptr + offset_idx)

    # compute colwise amaxes for this group
    amax_buffer = tl.zeros((BLOCK_SIZE_COLS,), dtype=tl.float64)
    for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ROWS):
        block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ROWS)
        block_offs = (
            block_row_offs[:, None] * stride_input_row
            + block_col_offs[None, :] * stride_input_col
        )
        # masked-out elements load as 0.0, which cannot raise an abs-max
        block_mask = (block_row_offs[:, None] < group_row_end_idx) & col_mask[None, :]
        data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to(
            input_dtype
        )
        amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=0))

    # compute colwise scales for this group; amax is clamped to EPS so an
    # all-zero column does not divide by zero
    scales = (fp8_dtype_max / tl.clamp(amax_buffer, min=EPS, max=float("inf"))).to(
        tl.float32
    )
    # optionally round each scale down to a power of 2
    if round_scales_to_power_of_2:
        scales = tl.exp2(tl.floor(tl.log2(scales)))

    # store colwise scales for each group in contiguous memory:
    # [group0_col0, group_0_col1, ..., group2_col0, group2_col1]
    # note: input tensor is in col-major memory layout.
    # The mask must use the global column index: masking on the local lane
    # index would let a partial last column block write into the next group's
    # scales (or past the end of the buffer for the last group).
    scales_offs = block_col_offs + (N * offset_idx)
    tl.store(scales_ptr + scales_offs, scales, mask=col_mask)

    # perform float8 conversion for this group
    for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ROWS):
        block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ROWS)
        block_offs = (
            block_row_offs[:, None] * stride_input_row
            + block_col_offs[None, :] * stride_input_col
        )
        # Bound rows by the group's end for BOTH the load and the store. Storing
        # with a `< K` row mask would write the zero-filled, masked-out rows of
        # the last tile over the next group's output.
        block_mask = (block_row_offs[:, None] < group_row_end_idx) & col_mask[None, :]
        data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to(
            input_dtype
        )
        scaled_data = data * scales[None, :]
        fp8_data = tl.clamp(scaled_data, min=fp8_dtype_min, max=fp8_dtype_max).to(
            output_dtype
        )
        out_offs = (
            block_row_offs[:, None] * stride_output_row
            + block_col_offs[None, :] * stride_output_col
        )
        tl.store(out_ptr + out_offs, fp8_data, mask=block_mask)

0 commit comments

Comments
 (0)