-
Notifications
You must be signed in to change notification settings - Fork 2
Add stable diffusion training script for PyTorch/XLA #3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
7aab092
Add training script for torch_xla
bhavya01 65a80d8
Add README_xla.md
bhavya01 127c10c
Update README_XLA.md to install torch and torch_xla from 09/05
bhavya01 f2d2e18
Remove unnecessary args
bhavya01 72d1298
Remove unnecessary imports
bhavya01 23e2c54
Fix spaces in README_xla.md
bhavya01 158dd7b
print exception while dataloading and update batch size in readme
bhavya01 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
# Stable Diffusion text-to-image fine-tuning using PyTorch/XLA | ||
|
||
The `train_text_to_image_xla.py` script shows how to fine-tune stable diffusion model on TPU devices using PyTorch/XLA. | ||
|
||
It has been tested on v4 and v5p TPU versions. | ||
|
||
This script implements Distributed Data Parallel using GSPMD feature in XLA compiler | ||
where we shard the input batches over the TPU devices. | ||
|
||
## Create TPU | ||
|
||
To create a TPU on Google Cloud first set these environment variables: | ||
|
||
```bash | ||
export TPU_NAME=<tpu-name> | ||
export PROJECT_ID=<project-id> | ||
export ZONE=<google-cloud-zone> | ||
export ACCELERATOR_TYPE=<accelerator type like v5p-8> | ||
export RUNTIME_VERSION=<runtime version like v2-alpha-tpuv5 for v5p> | ||
``` | ||
|
||
Then run the create TPU command: | ||
```bash | ||
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --project ${PROJECT_ID} | ||
--zone ${ZONE} --accelerator-type ${ACCELERATOR_TYPE} --version ${RUNTIME_VERSION} | ||
--reserved | ||
``` | ||
|
||
You can also use other ways to reserve TPUs like GKE or queued resources. | ||
|
||
## Setup TPU environment | ||
|
||
Install PyTorch and PyTorch/XLA nightly versions: | ||
```bash | ||
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ | ||
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \ | ||
--command=' | ||
pip3 install --pre torch==2.5.0.dev20240905+cpu torchvision==0.20.0.dev20240905+cpu --index-url https://download.pytorch.org/whl/nightly/cpu | ||
pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240905-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html' | ||
``` | ||
This script has been tested with the above versions but it expected to work with future versions as well. | ||
|
||
Verify that PyTorch and PyTorch/XLA were installed correctly: | ||
|
||
```bash | ||
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ | ||
--project ${PROJECT_ID} --zone ${ZONE} --worker=all \ | ||
--command='python3 -c "import torch; import torch_xla;"' | ||
``` | ||
|
||
Install this fork of huggingface diffusers repo: | ||
```bash | ||
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ | ||
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \ | ||
--command=' | ||
git clone https://github.com/pytorch-tpu/diffusers.git | ||
cd diffusers | ||
git checkout main | ||
cd examples/text_to_image | ||
pip install -r requirements.txt | ||
cd ../.. | ||
sudo pip install -e .' | ||
``` | ||
|
||
## Run the training job | ||
|
||
This script only trains the unet part of the network. The VAE and text encoder | ||
are fixed. | ||
|
||
```bash | ||
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ | ||
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \ | ||
--command=' | ||
export XLA_DISABLE_FUNCTIONALIZATION=1 | ||
export PROFILE_DIR=/tmp/profile # Update the directory to store profiles if needed. | ||
export CACHE_DIR=/tmp/xla_cache # Update the cache to store compiled XLA graphs if needed. | ||
export DATASET_NAME=lambdalabs/naruto-blip-captions | ||
export PER_HOST_BATCH_SIZE=16 # This is know to work on TPU v4. Can set this to 64 for TPU v5p. | ||
export TRAIN_STEPS=50 | ||
export OUTPUT_DIR=/tmp/output/ | ||
python diffusers/examples/text_to_image/train_text_to_image_xla.py \ | ||
--pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base \ | ||
--dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip \ | ||
--train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS \ | ||
--learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 \ | ||
--output_dir=$OUTPUT_DIR --dataloader_num_workers=4 \ | ||
--loader_prefetch_size=4 --device_prefetch_size=4' | ||
``` | ||
|
||
### Environment Envs Explained | ||
|
||
* `XLA_DISABLE_FUNCTIONALIZATION`: To optimize the performance for AdamW optimizer. | ||
* `PROFILE_DIR`: Specify where to put the profiling results. | ||
* `CACHE_DIR`: Directory to store XLA compiled graphs for persistent caching. | ||
* `DATASET_NAME`: Dataset to train the model. | ||
* `PER_HOST_BATCH_SIZE`: Size of the batch to load per CPU host. For e.g. for a v5p-16 with 2 CPU hosts, the global batch size will be 2xPER_HOST_BATCH_SIZE. The input batch is sharded along the batch axis. | ||
* `TRAIN_STEPS`: Total number of training steps to run the training for. | ||
* `OUTPUT_DIR`: Directory to store the fine-tuned model. | ||
|
||
## Run inference using the output model | ||
|
||
To run inference using the output, you can simply load the model and pass it | ||
input prompts: | ||
|
||
```python | ||
import torch | ||
import os | ||
import sys | ||
import numpy as np | ||
|
||
import torch_xla.core.xla_model as xm | ||
from time import time | ||
from typing import Tuple | ||
from diffusers import StableDiffusionPipeline | ||
|
||
def main(args): | ||
device = xm.xla_device() | ||
model_path = <output_dir> | ||
pipe = StableDiffusionPipeline.from_pretrained( | ||
model_path, | ||
torch_dtype=torch.bfloat16 | ||
) | ||
pipe.to(device) | ||
prompt = ["A naruto with green eyes and red legs."] | ||
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0] | ||
image.save("naruto.png") | ||
|
||
if __name__ == '__main__': | ||
main() | ||
``` |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.