
Commit 2841504

enable quantcompile test on xpu (#11988)
Signed-off-by: Yao, Matrix <[email protected]>
1 parent: 3d2f8ae

File tree: 2 files changed (+7 −7)

tests/quantization/test_torch_compile_utils.py (5 additions & 5 deletions)
@@ -18,10 +18,10 @@
 import torch
 
 from diffusers import DiffusionPipeline
-from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import backend_empty_cache, require_torch_accelerator, slow, torch_device
 
 
-@require_torch_gpu
+@require_torch_accelerator
 @slow
 class QuantCompileTests:
     @property
@@ -51,7 +51,7 @@ def _init_pipeline(self, quantization_config, torch_dtype):
         return pipe
 
     def _test_torch_compile(self, torch_dtype=torch.bfloat16):
-        pipe = self._init_pipeline(self.quantization_config, torch_dtype).to("cuda")
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype).to(torch_device)
         # `fullgraph=True` ensures no graph breaks
         pipe.transformer.compile(fullgraph=True)
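The comment in this hunk is the crux of the test: with fullgraph=True, torch.compile raises on any graph break instead of silently splitting the model into subgraphs. A minimal, self-contained illustration of the flag (not tied to diffusers):

import torch

# fullgraph=True turns graph breaks into hard errors, so compiling the
# transformer this way doubles as a graph-break regression test.
mod = torch.nn.Linear(8, 8)
compiled = torch.compile(mod, fullgraph=True)
out = compiled(torch.randn(2, 8))  # compiles and runs as a single graph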

@@ -71,7 +71,7 @@ def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16
 
         pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         group_offload_kwargs = {
-            "onload_device": torch.device("cuda"),
+            "onload_device": torch.device(torch_device),
             "offload_device": torch.device("cpu"),
             "offload_type": "leaf_level",
             "use_stream": use_stream,
@@ -81,7 +81,7 @@ def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16
         for name, component in pipe.components.items():
             if name != "transformer" and isinstance(component, torch.nn.Module):
                 if torch.device(component.device).type == "cpu":
-                    component.to("cuda")
+                    component.to(torch_device)
 
         # small resolutions to ensure speedy execution.
         pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

tests/quantization/torchao/test_torchao.py (2 additions & 2 deletions)
@@ -236,7 +236,7 @@ def test_quantization(self):
             ("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
         ]
 
-        if TorchAoConfig._is_cuda_capability_atleast_8_9():
+        if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
             QUANTIZATION_TYPES_TO_TEST.extend([
                 ("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
                 ("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
@@ -753,7 +753,7 @@ def test_quantization(self):
             ("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
         ]
 
-        if TorchAoConfig._is_cuda_capability_atleast_8_9():
+        if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
             QUANTIZATION_TYPES_TO_TEST.extend([
                 ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
                 ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
