Commit 74f74ea

Merge pull request #3 from pytorch-tpu/ptxla_training
Add stable diffusion training script for PyTorch/XLA
2 parents 0326225 + 158dd7b

File tree: 2 files changed (+700, -0 lines)

examples/text_to_image/README_xla.md

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
# Stable Diffusion text-to-image fine-tuning using PyTorch/XLA

The `train_text_to_image_xla.py` script shows how to fine-tune the Stable Diffusion model on TPU devices using PyTorch/XLA.

It has been tested on v4 and v5p TPU versions.

This script implements Distributed Data Parallel training using the GSPMD feature of the XLA compiler: the input batches are sharded across the TPU devices.
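
If you have not used GSPMD before, the snippet below is a minimal sketch (not taken from the training script) of how batch-axis sharding is expressed with PyTorch/XLA's SPMD API; the tensor shape and mesh layout are illustrative assumptions:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # switch the PyTorch/XLA runtime into SPMD mode

# Build a 1-D mesh over all TPU devices; the single axis is used for data parallelism.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

# A stand-in image batch (shape is an assumption). Annotating dim 0 with the
# 'data' mesh axis tells the compiler to shard the batch across devices,
# which is how GSPMD realizes data-parallel training.
batch = torch.randn(16, 3, 512, 512).to(xm.xla_device())
xs.mark_sharding(batch, mesh, ('data', None, None, None))
```
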
## Create TPU

To create a TPU on Google Cloud, first set these environment variables:

```bash
export TPU_NAME=<tpu-name>
export PROJECT_ID=<project-id>
export ZONE=<google-cloud-zone>
export ACCELERATOR_TYPE=<accelerator type like v5p-8>
export RUNTIME_VERSION=<runtime version like v2-alpha-tpuv5 for v5p>
```

Then run the TPU creation command:
```bash
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --accelerator-type ${ACCELERATOR_TYPE} --version ${RUNTIME_VERSION} \
--reserved
```

You can also reserve TPUs in other ways, such as through GKE or queued resources.

## Set up the TPU environment

Install the PyTorch and PyTorch/XLA nightly builds:
```bash
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='
pip3 install --pre torch==2.5.0.dev20240905+cpu torchvision==0.20.0.dev20240905+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240905-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html'
```
This script has been tested with the above versions, but it is expected to work with future versions as well.

Verify that PyTorch and PyTorch/XLA were installed correctly:

```bash
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} --zone ${ZONE} --worker=all \
--command='python3 -c "import torch; import torch_xla;"'
```
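
Optionally, you can also confirm that the TPU devices are actually visible to PyTorch/XLA. The following is a minimal sketch, assuming the same nightly builds as above:

```python
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Prints an XLA device such as 'xla:0' and the number of addressable
# TPU devices when the runtime is configured correctly.
print(xm.xla_device())
print(xr.global_runtime_device_count())
```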

Install this fork of the Hugging Face diffusers repo:
```bash
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='
git clone https://github.com/pytorch-tpu/diffusers.git
cd diffusers
git checkout main
cd examples/text_to_image
pip install -r requirements.txt
cd ../..
sudo pip install -e .'
```
## Run the training job

This script only trains the UNet part of the network. The VAE and text encoder are kept fixed.

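Inside the training script this amounts to freezing the VAE and text encoder while leaving the UNet trainable. A minimal sketch of that pattern (illustrative only, not the script's exact code) looks like:

```python
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel

model_id = "stabilityai/stable-diffusion-2-base"
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

# Only the UNet receives gradients; the VAE and text encoder stay fixed.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.train()

# Matches the learning rate passed to the training command below.
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-6)
```

The actual training run is launched with the command below:
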
```bash
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='
export XLA_DISABLE_FUNCTIONALIZATION=1
export PROFILE_DIR=/tmp/profile # Update the directory to store profiles if needed.
export CACHE_DIR=/tmp/xla_cache # Update the cache to store compiled XLA graphs if needed.
export DATASET_NAME=lambdalabs/naruto-blip-captions
export PER_HOST_BATCH_SIZE=16 # This is known to work on TPU v4. Can be set to 64 for TPU v5p.
export TRAIN_STEPS=50
export OUTPUT_DIR=/tmp/output/
python diffusers/examples/text_to_image/train_text_to_image_xla.py \
--pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base \
--dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip \
--train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS \
--learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 \
--output_dir=$OUTPUT_DIR --dataloader_num_workers=4 \
--loader_prefetch_size=4 --device_prefetch_size=4'
```

### Environment Variables Explained

* `XLA_DISABLE_FUNCTIONALIZATION`: Disables functionalization to optimize the performance of the AdamW optimizer.
* `PROFILE_DIR`: Where to write the profiling results.
* `CACHE_DIR`: Directory to store compiled XLA graphs for persistent caching.
* `DATASET_NAME`: Dataset used to train the model.
* `PER_HOST_BATCH_SIZE`: Size of the batch to load per CPU host. For example, on a v5p-16 with 2 CPU hosts, the global batch size will be 2 × `PER_HOST_BATCH_SIZE`. The input batch is sharded along the batch axis; see the arithmetic sketch after this list.
* `TRAIN_STEPS`: Total number of training steps to run.
* `OUTPUT_DIR`: Directory to store the fine-tuned model.

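As a concrete illustration of the batch-size arithmetic above (the two-host v5p-16 configuration is an assumption taken from the example in the list):

```python
# Hypothetical illustration of the global batch size arithmetic.
per_host_batch_size = 16  # value suggested for TPU v4 in this README
num_hosts = 2             # e.g. a v5p-16 slice has 2 CPU hosts

# Each host loads its own batch; GSPMD shards the combined batch
# along the batch axis across all TPU devices.
global_batch_size = num_hosts * per_host_batch_size
print(global_batch_size)  # -> 32
```
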
## Run inference using the output model

To run inference using the output, you can simply load the model and pass it input prompts:

```python
import torch
import torch_xla.core.xla_model as xm
from diffusers import StableDiffusionPipeline


def main():
    device = xm.xla_device()
    model_path = "<output_dir>"  # Directory the fine-tuned model was saved to.
    pipe = StableDiffusionPipeline.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
    )
    pipe.to(device)
    prompt = ["A naruto with green eyes and red legs."]
    image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
    image.save("naruto.png")


if __name__ == '__main__':
    main()
```
