@@ -15,6 +15,7 @@
 from torchao.prototype.moe_training.kernels import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
+    triton_fp8_rowwise_3d_transpose_rhs,
 )
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
@@ -142,20 +143,11 @@ def forward(
         # Precompute non-transposed B column-major for backward, to save memory by storing the
         # low precision B tensor instead of the high precision B tensor.
         # In the backward this is needed for grad_A: grad_output @ B.
-        B = B_t.contiguous().transpose(-2, -1)
-
-        # - B shape: (E, N, K)
-        # - B scales must be computed rowwise keeping the outer/final dim, so:
-        #   - B_scale shape: (E, 1, K)
-        B_scales = tensor_to_scale(
-            B,
-            torch.float8_e4m3fn,
-            scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=-2,
+        B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
+            B_t,
+            output_dtype=torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )
-        B_scaled = B.to(torch.float32) * B_scales
-        B_fp8_col_major = to_fp8_saturated(B_scaled, torch.float8_e4m3fn)
 
         # Store what we need for backward.
         ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs)
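For context, the fused `triton_fp8_rowwise_3d_transpose_rhs` kernel collapses the removed eager-mode steps (transpose B_t to a column-major B, compute rowwise scales along dim=-2, scale, saturate, cast to fp8) into a single pass. Below is a minimal plain-PyTorch sketch of that eager path, assuming the kernel's semantics match the removed code; the helper name `fp8_rowwise_transpose_rhs_reference` and the exact power-of-2 rounding mode are illustrative assumptions, not the library's API.

```python
import torch

def fp8_rowwise_transpose_rhs_reference(B_t: torch.Tensor):
    """Eager-mode sketch of the rowwise-scaled fp8 transpose the Triton kernel fuses.

    Assumes B_t has shape (E, K, N); returns (B_fp8_col_major, B_scales) with
    B of shape (E, N, K) and B_scales of shape (E, 1, K), mirroring the removed code.
    """
    # Transpose to (E, N, K); contiguous() first so the result is column-major
    # in its last two dims, as required for the RHS of the scaled matmul.
    B = B_t.contiguous().transpose(-2, -1)

    # Rowwise (axiswise) amax along dim=-2, keeping the outer/final dim: (E, 1, K).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    amax = B.abs().amax(dim=-2, keepdim=True).to(torch.float64)
    scales = (fp8_max / amax.clamp(min=1e-12)).to(torch.float32)

    # round_scales_to_power_of_2=True: round each scale down to a power of two
    # (illustrative; the exact rounding behavior is an assumption).
    scales = torch.exp2(torch.floor(torch.log2(scales)))

    # Scale, saturate to the representable fp8 range, then cast.
    B_scaled = (B.to(torch.float32) * scales).clamp(min=-fp8_max, max=fp8_max)
    return B_scaled.to(torch.float8_e4m3fn), scales
```

Fusing these steps into one kernel avoids materializing the high-precision transposed copy of B, which is the memory saving the surrounding comment describes. The saved tensors are still sufficient for the backward's grad_A = grad_output @ B, since conceptually B can be recovered as roughly B_fp8_col_major.float() / B_scales.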