5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
import torch
8
- import torch .nn .functional as F
9
8
10
9
Tensor = torch .Tensor
11
10
@@ -31,14 +30,23 @@ def to_blocked(input_matrix) -> Tensor:
31
30
n_row_blocks = ceil_div (rows , 128 )
32
31
n_col_blocks = ceil_div (cols , 4 )
33
32
34
- # Pad out and view as tiles of (128, 4)
35
- padded = F .pad (input_matrix , (0 , - cols % 4 , 0 , - rows % 128 ))
36
- blocks = padded .view (n_row_blocks , 128 , n_col_blocks , 4 ).permute (0 , 2 , 1 , 3 )
33
+ # Calculate the padded shape
34
+ padded_rows = n_row_blocks * 128
35
+ padded_cols = n_col_blocks * 4
36
+
37
+ padded = input_matrix
38
+ if (rows , cols ) != (padded_rows , padded_cols ):
39
+ padded = torch .zeros (
40
+ (padded_rows , padded_cols ),
41
+ device = input_matrix .device ,
42
+ dtype = input_matrix .dtype ,
43
+ )
44
+ padded [:rows , :cols ] = input_matrix
37
45
38
- # rearrange all tiles
46
+ # Rearrange the blocks
47
+ blocks = padded .view (n_row_blocks , 128 , n_col_blocks , 4 ).permute (0 , 2 , 1 , 3 )
39
48
rearranged = blocks .reshape (- 1 , 4 , 32 , 4 ).transpose (1 , 2 ).reshape (- 1 , 32 , 16 )
40
49
41
- # Layout rearranged tiles according to second pic
42
50
return rearranged .flatten ()
43
51
44
52
0 commit comments