
Enable xpu device #1736

Open: wants to merge 7 commits into base: main
2 changes: 1 addition & 1 deletion examples/awq/llama_example.py
@@ -66,7 +66,7 @@ def tokenize(sample):
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
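
The same one-line substitution repeats throughout the examples below: instead of hard-coding `"cuda"`, inputs are moved to `model.device`, which resolves to whichever backend the model was actually dispatched to (CUDA, XPU, or CPU). A minimal sketch of the pattern, using a hypothetical model stub in place of the one defined in the example:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"  # hypothetical stub, for illustration only

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# model.device reports where the model's (first) parameters were placed, so the
# inputs follow the model onto cuda, xpu, or cpu without a hard-coded device string.
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=20)
print(tokenizer.decode(output[0]))
```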
2 changes: 1 addition & 1 deletion examples/awq/qwen3_moe_example.py
@@ -71,7 +71,7 @@ def tokenize(sample):
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
2 changes: 1 addition & 1 deletion examples/big_models_with_sequential_onloading/README.md
@@ -37,7 +37,7 @@ During `oneshot`, only one gpu is required which will be used to onload each lay
```python
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
- sample = {key: value.to("cuda") for key, value in sample.items()}
+ sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
```
@@ -76,7 +76,7 @@ def tokenize(sample):
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
- sample = {key: value.to("cuda") for key, value in sample.items()}
+ sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
2 changes: 1 addition & 1 deletion examples/compressed_inference/fp8_compressed_inference.py
@@ -22,7 +22,7 @@
compressed_model = AutoModelForCausalLM.from_pretrained(
    MODEL_STUB,
    torch_dtype="auto",
-    device_map="cuda:0",
+    device_map="auto",
)

Collaborator (inline review comment on the `device_map` change): will confirm with team that this is what we want here.

# tokenize the sample data
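
One way to sanity-check what `device_map="auto"` actually chose (a sketch, assuming the load above succeeded) is to inspect the placement accelerate recorded on the model:

```python
# With device_map="auto", transformers/accelerate attach the chosen placement to
# hf_device_map; model.device reports where the first parameters live.
print(compressed_model.hf_device_map)  # e.g. {"": "cuda:0"}, {"": "xpu:0"}, or {"": "cpu"}
print(compressed_model.device)
```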
2 changes: 1 addition & 1 deletion examples/multimodal_vision/gemma3_example.py
@@ -68,7 +68,7 @@ def data_collator(batch):
raw_image = Image.open(requests.get(image_url, stream=True).raw)

# Note: compile is disabled: https://github.com/huggingface/transformers/issues/38333
- inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
+ inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=100, disable_compile=True)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")
2 changes: 1 addition & 1 deletion examples/multimodal_vision/idefics3_example.py
@@ -109,7 +109,7 @@ def tokenize(sample):
image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
raw_image = Image.open(requests.get(image_url, stream=True).raw)

- inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
+ inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")
2 changes: 1 addition & 1 deletion examples/multimodal_vision/llava_example.py
@@ -64,7 +64,7 @@ def data_collator(batch):
image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
raw_image = Image.open(requests.get(image_url, stream=True).raw)

- inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
+ inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")
2 changes: 1 addition & 1 deletion examples/multimodal_vision/mistral3_example.py
@@ -77,7 +77,7 @@ def data_collator(batch):
image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
raw_image = Image.open(requests.get(image_url, stream=True).raw)

- inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
+ inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(model.device)
inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)  # fix dtype
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
2 changes: 1 addition & 1 deletion examples/multimodal_vision/mllama_example.py
@@ -64,7 +64,7 @@ def data_collator(batch):
image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
raw_image = Image.open(requests.get(image_url, stream=True).raw)

- inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
+ inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")
2 changes: 1 addition & 1 deletion examples/multimodal_vision/phi3_vision_example.py
@@ -93,7 +93,7 @@ def data_collator(batch):
# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
print("==========================================")
2 changes: 1 addition & 1 deletion examples/multimodal_vision/pixtral_example.py
@@ -70,7 +70,7 @@ def data_collator(batch):
image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
raw_image = Image.open(requests.get(image_url, stream=True).raw)

- inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
+ inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")
2 changes: 1 addition & 1 deletion examples/multimodal_vision/qwen2_vl_example.py
@@ -121,7 +121,7 @@ def data_collator(batch):
    max_length=MAX_SEQUENCE_LENGTH,
    truncation=True,
    return_tensors="pt",
- ).to("cuda")
+ ).to(model.device)
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")
2 changes: 1 addition & 1 deletion examples/multimodal_vision/qwen_2_5_vl_example.py
@@ -115,7 +115,7 @@ def data_collator(batch):
    max_length=MAX_SEQUENCE_LENGTH,
    truncation=True,
    return_tensors="pt",
- ).to("cuda")
+ ).to(model.device)
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")
2 changes: 1 addition & 1 deletion examples/quantization_kv_cache/README.md
@@ -115,7 +115,7 @@ oneshot(
Test the quantized model with a sample generation:

```python
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
```
2 changes: 1 addition & 1 deletion examples/quantization_kv_cache/gemma2_fp8_kv_example.py
@@ -91,7 +91,7 @@ def process_and_tokenize(example):
print("\n\n")
dispatch_for_generation(model)
print("========== SAMPLE GENERATION ==============")
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=100, disable_compile=True)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
2 changes: 1 addition & 1 deletion examples/quantization_kv_cache/llama3_fp8_kv_example.py
@@ -88,7 +88,7 @@ def process_and_tokenize(example):
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
2 changes: 1 addition & 1 deletion examples/quantization_kv_cache/phi3.5_fp8_kv_example.py
@@ -88,7 +88,7 @@ def process_and_tokenize(example):
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
@@ -57,7 +57,7 @@
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
@@ -92,7 +92,7 @@ def tokenize(sample):
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
@@ -109,7 +109,7 @@ def tokenize(sample):
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
2 changes: 1 addition & 1 deletion examples/quantization_w4a16/llama3_example.py
@@ -67,7 +67,7 @@ def tokenize(sample):
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
- sample = {key: value.to("cuda") for key, value in sample.items()}
+ sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
2 changes: 1 addition & 1 deletion examples/quantization_w4a16_fp4/llama3_example.py
@@ -21,7 +21,7 @@
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
2 changes: 1 addition & 1 deletion examples/quantization_w4a4_fp4/llama3_example.py
@@ -69,7 +69,7 @@ def tokenize(sample):
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
2 changes: 1 addition & 1 deletion examples/quantization_w4a4_fp4/qwen_30b_a3b.py
@@ -77,7 +77,7 @@ def tokenize(sample):
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
2 changes: 1 addition & 1 deletion examples/quantization_w8a8_fp8/fp8_block_example.py
@@ -26,7 +26,7 @@

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=20)
print(tokenizer.decode(output[0]))
print("==========================================")
2 changes: 1 addition & 1 deletion examples/quantization_w8a8_fp8/gemma2_example.py
@@ -32,7 +32,7 @@
# Note: compile is disabled: https://github.com/huggingface/transformers/issues/38333
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=20, disable_compile=True)
print(tokenizer.decode(output[0]))
print("==========================================")
2 changes: 1 addition & 1 deletion examples/quantization_w8a8_fp8/llama3.2_vision_example.py
@@ -26,7 +26,7 @@
# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
print("==========================================")
2 changes: 1 addition & 1 deletion examples/quantization_w8a8_fp8/llama3_example.py
@@ -24,7 +24,7 @@
# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=20)
print(tokenizer.decode(output[0]))
print("==========================================")
2 changes: 1 addition & 1 deletion examples/quantization_w8a8_fp8/llava1.5_example.py
@@ -26,7 +26,7 @@
# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
print("==========================================")
2 changes: 1 addition & 1 deletion examples/quantization_w8a8_fp8/qwen2vl_example.py
@@ -26,7 +26,7 @@
# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
print("==========================================")
2 changes: 1 addition & 1 deletion examples/quantization_w8a8_int8/gemma2_example.py
@@ -70,7 +70,7 @@ def tokenize(sample):
# Note: compile is disabled: https://github.com/huggingface/transformers/issues/38333
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=20, disable_compile=True)
print(tokenizer.decode(output[0]))
print("==========================================")
2 changes: 1 addition & 1 deletion examples/quantization_w8a8_int8/llama3_example.py
@@ -72,7 +72,7 @@ def tokenize(sample):
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
2 changes: 1 addition & 1 deletion examples/quantizing_moe/mixtral_example.py
@@ -75,7 +75,7 @@ def tokenize(sample):
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
- sample = {key: value.to("cuda") for key, value in sample.items()}
+ sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================")
2 changes: 1 addition & 1 deletion examples/quantizing_moe/qwen_example.py
@@ -74,7 +74,7 @@ def tokenize(sample):
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
- sample = {key: value.to("cuda") for key, value in sample.items()}
+ sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================")
2 changes: 1 addition & 1 deletion examples/sparse_2of4_quantization_fp8/llama3_8b_2of4.py
@@ -103,7 +103,7 @@ def get_recipe(fp8_enabled):
# Validate the compressed model
print("\n========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n")
2 changes: 1 addition & 1 deletion examples/transform/quip_example.py
@@ -32,7 +32,7 @@
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
2 changes: 1 addition & 1 deletion examples/transform/spinquant_example.py
@@ -29,7 +29,7 @@
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
- input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+ input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
6 changes: 4 additions & 2 deletions src/llmcompressor/pipelines/sequential/helpers.py
@@ -517,7 +517,7 @@ def is_ancestor(module: Module) -> bool:
def dispatch_for_sequential(model: PreTrainedModel) -> PreTrainedModel:
    """
    Dispatch a model for sequential calibration using a sequential pipeline.
-    The model will be offloaded to the CPU and dispatched to CUDA device if available.
+    The model will be offloaded to the CPU and dispatched to CUDA/XPU device if available.
    Removes any existing hooks.

    :param model: model to dispatch
@@ -527,8 +527,10 @@ def dispatch_for_sequential(model: PreTrainedModel) -> PreTrainedModel:

    if torch.cuda.is_available():
        offloaded_dispatch(model, execution_device=torch.device("cuda:0"))
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        offloaded_dispatch(model, execution_device=torch.device("xpu:0"))
    else:
-        logger.warning("CUDA is not available! Compressing model on CPU instead")
+        logger.warning("CUDA/XPU is not available! Compressing model on CPU instead")

    return model
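
For readers following the dispatch logic, the fallback order above is CUDA first, then Intel XPU, then CPU. A standalone sketch of that selection (an illustrative helper, not part of this PR):

```python
import torch


def preferred_execution_device() -> torch.device:
    """Mirror the fallback order used by dispatch_for_sequential above (hypothetical helper)."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    # The hasattr guard keeps this safe on PyTorch builds that ship without the XPU backend.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu:0")
    return torch.device("cpu")
```

The `hasattr(torch, "xpu")` check matters because older PyTorch releases do not expose `torch.xpu` at all, so calling `torch.xpu.is_available()` directly would raise an `AttributeError`.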
10 changes: 8 additions & 2 deletions src/llmcompressor/transformers/finetune/session_mixin.py
@@ -170,7 +170,10 @@ def initialize_session(
        "pass a yaml file or string to the `recipe` argument."
    )

-    torch.cuda.empty_cache()
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        torch.xpu.empty_cache()
+    else:
+        torch.cuda.empty_cache()

def finalize_session(self):
    """
@@ -186,7 +189,10 @@ def finalize_session(self):
    logger.info("Finalized LLM Compressor session")
    model = get_session_model()
    self.model = model
-    torch.cuda.empty_cache()
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        torch.xpu.empty_cache()
+    else:
+        torch.cuda.empty_cache()

def create_optimizer(self):
    """
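
The same two-branch cache clear now appears in both `initialize_session` and `finalize_session`; a possible follow-up (an assumption, not something this PR does) would be to factor it into a small device-agnostic helper:

```python
import torch


def empty_accelerator_cache() -> None:
    """Release cached allocator memory on whichever accelerator is active (hypothetical helper)."""
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()
    # On CPU-only machines there is no allocator cache to release, so this is a no-op.
```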