@@ -236,7 +236,7 @@ def test_quantization(self):
236
236
("uint7wo" , np .array ([0.4648 , 0.5195 , 0.5547 , 0.4219 , 0.4414 , 0.6445 , 0.4316 , 0.4531 , 0.5625 ])),
237
237
]
238
238
239
- if TorchAoConfig ._is_cuda_capability_atleast_8_9 ():
239
+ if TorchAoConfig ._is_xpu_or_cuda_capability_atleast_8_9 ():
240
240
QUANTIZATION_TYPES_TO_TEST .extend ([
241
241
("float8wo_e5m2" , np .array ([0.4590 , 0.5273 , 0.5547 , 0.4219 , 0.4375 , 0.6406 , 0.4316 , 0.4512 , 0.5625 ])),
242
242
("float8wo_e4m3" , np .array ([0.4648 , 0.5234 , 0.5547 , 0.4219 , 0.4414 , 0.6406 , 0.4316 , 0.4531 , 0.5625 ])),
@@ -753,7 +753,7 @@ def test_quantization(self):
753
753
("int8dq" , np .array ([0.0546 , 0.0761 , 0.1386 , 0.0488 , 0.0644 , 0.1425 , 0.0605 , 0.0742 , 0.1406 , 0.0625 , 0.0722 , 0.1523 , 0.0625 , 0.0742 , 0.1503 , 0.0605 , 0.3886 , 0.7968 , 0.5507 , 0.4492 , 0.7890 , 0.5351 , 0.4316 , 0.8007 , 0.5390 , 0.4179 , 0.8281 , 0.5820 , 0.4531 , 0.7812 , 0.5703 , 0.4921 ])),
754
754
]
755
755
756
- if TorchAoConfig ._is_cuda_capability_atleast_8_9 ():
756
+ if TorchAoConfig ._is_xpu_or_cuda_capability_atleast_8_9 ():
757
757
QUANTIZATION_TYPES_TO_TEST .extend ([
758
758
("float8wo_e4m3" , np .array ([0.0546 , 0.0722 , 0.1328 , 0.0468 , 0.0585 , 0.1367 , 0.0605 , 0.0703 , 0.1328 , 0.0625 , 0.0703 , 0.1445 , 0.0585 , 0.0703 , 0.1406 , 0.0605 , 0.3496 , 0.7109 , 0.4843 , 0.4042 , 0.7226 , 0.5000 , 0.4160 , 0.7031 , 0.4824 , 0.3886 , 0.6757 , 0.4667 , 0.3710 , 0.6679 , 0.4902 , 0.4238 ])),
759
759
("fp5_e3m1" , np .array ([0.0527 , 0.0762 , 0.1309 , 0.0449 , 0.0645 , 0.1328 , 0.0566 , 0.0723 , 0.125 , 0.0566 , 0.0703 , 0.1328 , 0.0566 , 0.0742 , 0.1348 , 0.0566 , 0.3633 , 0.7617 , 0.5273 , 0.4277 , 0.7891 , 0.5469 , 0.4375 , 0.8008 , 0.5586 , 0.4336 , 0.7383 , 0.5156 , 0.3906 , 0.6992 , 0.5156 , 0.4375 ])),
0 commit comments