Skip to content

Commit 2a3fbff

Browse files
authored
MX: Updated to_blocked to not call F.pad (torch.nn.functional.pad) (#1762)
stack-info: PR: #1762, branch: drisspg/stack/38
1 parent d370196 commit 2a3fbff

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

torchao/prototype/mx_formats/utils.py

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8-
import torch.nn.functional as F
98

109
Tensor = torch.Tensor
1110

@@ -31,14 +30,23 @@ def to_blocked(input_matrix) -> Tensor:
3130
n_row_blocks = ceil_div(rows, 128)
3231
n_col_blocks = ceil_div(cols, 4)
3332

34-
# Pad out and view as tiles of (128, 4)
35-
padded = F.pad(input_matrix, (0, -cols % 4, 0, -rows % 128))
36-
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
33+
# Calculate the padded shape
34+
padded_rows = n_row_blocks * 128
35+
padded_cols = n_col_blocks * 4
36+
37+
padded = input_matrix
38+
if (rows, cols) != (padded_rows, padded_cols):
39+
padded = torch.zeros(
40+
(padded_rows, padded_cols),
41+
device=input_matrix.device,
42+
dtype=input_matrix.dtype,
43+
)
44+
padded[:rows, :cols] = input_matrix
3745

38-
# rearrange all tiles
46+
# Rearrange the blocks
47+
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
3948
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
4049

41-
# Layout rearranged tiles according to second pic
4250
return rearranged.flatten()
4351

4452

0 commit comments

Comments
 (0)