1
1
import pytest
2
2
import torch
3
-
4
3
from torchao .float8 .float8_utils import compute_error
5
- from torchao .ops import mx_fp8_bf16
4
+ from torchao .ops import mx_fp8_bf16 , mx_fp4_bf16
6
5
from torchao .prototype .mx_formats .mx_tensor import MXTensor
7
6
from torchao .prototype .mx_formats .utils import to_blocked
8
- from torchao .utils import (
9
- TORCH_VERSION_AT_LEAST_2_4 ,
10
- is_sm_at_least_100 ,
11
- )
7
+ from torchao .utils import TORCH_VERSION_AT_LEAST_2_4 , is_sm_at_least_100
12
8
13
9
# Skip the whole module on older PyTorch: the MX ops exercised below
# rely on APIs only present from torch 2.4 onward.
if not TORCH_VERSION_AT_LEAST_2_4:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
15
11
16
12
17
def run_matrix_test(M: int, K: int, N: int, format: str = "fp8") -> float:
    """Run one MX-format matmul and return its SQNR against a reference.

    Args:
        M, K, N: Matrix dimensions; ``a`` is (M, K) and ``b`` is (N, K).
        format: Element format for the MX tensors — ``"fp8"``
            (``torch.float8_e4m3fn``) or ``"fp4"`` (``"fp4_e2m1"``).

    Returns:
        float: SQNR (Signal-to-Quantization-Noise Ratio) between the fused
        kernel's output and the dequantized bfloat16 reference matmul.

    Raises:
        ValueError: If ``format`` is not ``"fp8"`` or ``"fp4"``.
    """
    # Fail loudly on an unknown format instead of silently falling through
    # to the fp4 path (the original `else` branch accepted any string).
    if format not in ("fp8", "fp4"):
        raise ValueError(f"unsupported format {format!r}; expected 'fp8' or 'fp4'")

    dtype = torch.bfloat16
    device = torch.device("cuda")

    a = torch.rand((M, K), dtype=dtype, device=device)
    b = torch.rand((N, K), dtype=dtype, device=device)

    # Pick the MX element dtype and the matching fused kernel.
    fmt = torch.float8_e4m3fn if format == "fp8" else "fp4_e2m1"
    mx_func = mx_fp8_bf16 if format == "fp8" else mx_fp4_bf16

    # Quantize to MX with block size 32 (one e8m0 scale per 32 elements,
    # consistent with the K // 32 scale views below).
    a_mx = MXTensor.to_mx(a, fmt, 32)
    b_mx = MXTensor.to_mx(b, fmt, 32)

    a_data = a_mx._data
    b_data = b_mx._data
    assert b_data.is_contiguous()
    # NOTE(review): the kernel presumably expects b in transposed layout
    # while the scales remain row-wise — confirm against the op's contract.
    b_data = b_data.transpose(-1, -2)

    a_scale = a_mx._scale_e8m0.view(M, K // 32)
    b_scale = b_mx._scale_e8m0.view(N, K // 32)

    # Rearrange the per-row scales into the blocked layout the kernel consumes.
    a_scale_block = to_blocked(a_scale)
    b_scale_block = to_blocked(b_scale)

    # Reference: dequantize both operands to bfloat16 and matmul in high precision.
    out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
        -1, -2
    )
    out = mx_func(a_data, b_data, a_scale_block, b_scale_block)

    return compute_error(out_hp, out).item()
62
43
63
44
64
45
@pytest .mark .skipif (not torch .cuda .is_available (), reason = "CUDA not available" )
@@ -68,35 +49,25 @@ def run_matrix_test(M: int, K: int, N: int) -> float:
68
49
@pytest.mark.parametrize(
    "size",
    [
        # Square shapes, small through large.
        (128, 128, 128),
        (256, 256, 256),
        (384, 384, 384),
        (512, 512, 512),
        (768, 768, 768),
        (1024, 1024, 1024),
        (8192, 8192, 8192),
        # Non-square shapes.
        (128, 256, 384),
        (256, 384, 512),
        # Shapes not aligned to the block size.
        (129, 256, 384),
        (133, 512, 528),
    ],
    ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
)
@pytest.mark.parametrize("format", ["fp8", "fp4"])
def test_matrix_multiplication(size, format):
    """Check that the MX matmul's SQNR meets the minimum quality bar."""
    min_sqnr = 80.0
    M, K, N = size
    sqnr = run_matrix_test(M, K, N, format)
    assert (
        sqnr >= min_sqnr
    ), f"{format} SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
0 commit comments