From 9a911f71958182bc5e8bef46652014e3062aba8b Mon Sep 17 00:00:00 2001 From: Qinwen Xu Date: Fri, 22 Nov 2024 18:41:45 +0000 Subject: [PATCH] add maxdiffusion sdxl on trillium --- .../DIffusion-XL-MaxDiffusion/README.md | 52 +++++++++++++++++++ .../sdxl-2xv6e-256-pbds-1.sh | 7 +++ .../sdxl-v6e-256-pbds-1.sh | 10 ++++ training/trillium/MAXDIFFUSION_README.md | 31 +++++++++++ 4 files changed, 100 insertions(+) create mode 100644 training/trillium/DIffusion-XL-MaxDiffusion/README.md create mode 100644 training/trillium/DIffusion-XL-MaxDiffusion/sdxl-2xv6e-256-pbds-1.sh create mode 100644 training/trillium/DIffusion-XL-MaxDiffusion/sdxl-v6e-256-pbds-1.sh create mode 100644 training/trillium/MAXDIFFUSION_README.md diff --git a/training/trillium/DIffusion-XL-MaxDiffusion/README.md b/training/trillium/DIffusion-XL-MaxDiffusion/README.md new file mode 100644 index 0000000..0c88177 --- /dev/null +++ b/training/trillium/DIffusion-XL-MaxDiffusion/README.md @@ -0,0 +1,52 @@ +# Instructions for training MaxDiffusion SDXL on TPU trillium + +## XPK setup +Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/XPK_README.md) to create your GKE cluster with XPK + +## Prep for Maxdiffusion +Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/MAXDIFFUSION_README.md) to install maxdiffusion and build docker image + +Download pretrained stable_xl_base from [huggingface](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main) +##### `gsutil -m cp -r stable-diffusion-xl-base-1.0 ${OUTPUT_DIR}/checkpoints ` + +Prepare dataset and store at local folder + +`python -m unittest input_pipeline_interface_test.InputPipelineInterface.test_make_pokemon_iterator_sdxl_cache` + +Upload prepared dataset to gcs location + +`gsutil -m cp -r /tmp/pokemon-gpt4-captions_xl ${OUTPUT_DIR}/dataset ` +## Run Maxdiffusion SDXL workloads on GKE + +### Test Env +jaxlib=0.4.35 +
+[maxdiffusion](https://github.com/AI-Hypercomputer/maxdiffusion.git)@269b6216ac65adb9e7044ec454879dc99856d5e9 + +### Starting workload + +From the maxdiffusion root directory, start your SDXL workload on v6e-256 + +``` +python3 ~/xpk/xpk.py workload create --cluster $CLUSTER_NAME --workload $USER-maxdiffusion --command "bash sdxl-v6e-256-pbds-1.sh $OUTPUT_DIR" \ +--base-docker-image=maxdiffusion_base_image \ +--tpu-type=v6e-256 --num-slices=1 --zone=$ZONE --project=$PROJECT_ID +``` + +From your workload logs, you should start seeing step time logs like the following: +``` +completed step: 254, seconds: 0.164, TFLOP/s/device: 123.764, loss: 0.055 +``` + +start your SDXL workload on multi-slices of v6e-256 + +``` +python3 ~/xpk/xpk.py workload create --cluster $CLUSTER_NAME --workload $USER-maxdiffusion --command "bash sdxl-2xv6e-256-pbds-1.sh $OUTPUT_DIR" \ +--base-docker-image=maxdiffusion_base_image \ +--tpu-type=v6e-256 --num-slices=2 --zone=$ZONE --project=$PROJECT_ID +``` + +From your workload logs, you should start seeing step time logs like the following: +``` +completed step: 92, seconds: 0.228, TFLOP/s/device: 89.120, loss: 0.057 +``` \ No newline at end of file diff --git a/training/trillium/DIffusion-XL-MaxDiffusion/sdxl-2xv6e-256-pbds-1.sh b/training/trillium/DIffusion-XL-MaxDiffusion/sdxl-2xv6e-256-pbds-1.sh new file mode 100644 index 0000000..6b72e91 --- /dev/null +++ b/training/trillium/DIffusion-XL-MaxDiffusion/sdxl-2xv6e-256-pbds-1.sh @@ -0,0 +1,7 @@ +export JAX_PLATFORMS="tpu,cpu" + +checkpoints=${OUTPUT_DIR}/checkpoints +dataset_path=${OUTPUT_DIR}/dataset + +ENABLE_PJRT_COMPATIBILITY=true TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true TPU_SLICE_BUILDER_DUMP_ICI=true JAX_FORCE_TPU_INIT=true ENABLE_TPUNETD_CLIENT=true python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml pretrained_model_name_or_path=${checkpoints}/models--stabilityai--stable-diffusion-xl-base-1.0 \ +revision=refs/pr/95 activations_dtype=bfloat16 weights_dtype=bfloat16
dataset_name=${dataset_path}/pokemon-gpt4-captions_xl resolution=1024 per_device_batch_size=1 jax_cache_dir=${OUTPUT_DIR}/cache_dir/ max_train_steps=100 attention=flash run_name=trillium-sdxl enable_profiler=True output_dir=${OUTPUT_DIR} diff --git a/training/trillium/DIffusion-XL-MaxDiffusion/sdxl-v6e-256-pbds-1.sh b/training/trillium/DIffusion-XL-MaxDiffusion/sdxl-v6e-256-pbds-1.sh new file mode 100644 index 0000000..3100038 --- /dev/null +++ b/training/trillium/DIffusion-XL-MaxDiffusion/sdxl-v6e-256-pbds-1.sh @@ -0,0 +1,10 @@ +export JAX_PLATFORMS="tpu,cpu" + +export LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_tpu_spmd_threshold_for_allgather_cse=1000000 --xla_jf_spmd_threshold_for_windowed_einsum_mib=1000000' +LIBTPU_INIT_ARGS+=' --xla_sc_disable_megacore_partitioning=true --xla_tpu_use_tc_device_shape_on_sc=true --tpu_use_continuations=true --xla_sc_enable_instruction_fusion=false --xla_sc_disjoint_spmem=false --2a886c8_chip_config_name=megachip_tccontrol --xla_jf_crs_combiner_threshold_count=10 --xla_tpu_enable_sparse_core_collective_offload_all_reduce=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true' + +checkpoints=${OUTPUT_DIR}/checkpoints +dataset_path=${OUTPUT_DIR}/dataset + +ENABLE_PJRT_COMPATIBILITY=true TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true TPU_SLICE_BUILDER_DUMP_ICI=true JAX_FORCE_TPU_INIT=true ENABLE_TPUNETD_CLIENT=true python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml
pretrained_model_name_or_path=${checkpoints}/models--stabilityai--stable-diffusion-xl-base-1.0 \ +revision=refs/pr/95 activations_dtype=bfloat16 weights_dtype=bfloat16 dataset_name=${dataset_path}/pokemon-gpt4-captions_xl resolution=1024 per_device_batch_size=1 jax_cache_dir=${OUTPUT_DIR}/cache_dir/ max_train_steps=100 attention=flash run_name=trillium-sdxl enable_profiler=True output_dir=${OUTPUT_DIR} diff --git a/training/trillium/MAXDIFFUSION_README.md b/training/trillium/MAXDIFFUSION_README.md new file mode 100644 index 0000000..236bdc0 --- /dev/null +++ b/training/trillium/MAXDIFFUSION_README.md @@ -0,0 +1,31 @@ +# Prep for Maxdiffusion workloads on GKE +1. Clone [maxdiffusion](https://github.com/AI-Hypercomputer/maxdiffusion.git) repo and move to its directory +``` +git clone https://github.com/AI-Hypercomputer/maxdiffusion.git +cd maxdiffusion +git checkout ${MAXDIFFUSION_HASH} +``` + +2. Run the following commands to build the docker image +``` +bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1 +``` + +3. Upload your docker image to Container Registry +``` +bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner +``` + +4. Create your GCS bucket +``` +OUTPUT_DIR=gs://v6e-demo-run # +gcloud storage buckets create ${OUTPUT_DIR} --project ${PROJECT} +``` + +5. Specify your workload configs +``` +export PROJECT=# +export ZONE=# +export CLUSTER_NAME=v6e-demo # +export OUTPUT_DIR=gs://v6e-demo-run/ # +``` \ No newline at end of file