From 1498d3dfc2cd363d4061de3c45495e92232144e6 Mon Sep 17 00:00:00 2001 From: mailvijayasingh Date: Mon, 3 Jun 2024 19:50:08 +0000 Subject: [PATCH 1/7] AQT Einsum Dot General --- requirements.txt | 3 +- src/maxdiffusion/configs/base_xl.yml | 7 +- src/maxdiffusion/max_utils.py | 4 +- src/maxdiffusion/models/attention_flax.py | 34 +- src/maxdiffusion/models/quantizations.py | 131 ++++++++ .../models/unet_2d_blocks_flax.py | 23 +- .../models/unet_2d_condition_flax.py | 12 +- .../pipelines/pipeline_flax_utils.py | 2 + src/maxdiffusion/pyconfig.py | 21 +- src/maxdiffusion/unet_quantization.py | 295 ++++++++++++++++++ src/maxdiffusion/unet_quantization_utils.py | 7 + 11 files changed, 529 insertions(+), 10 deletions(-) create mode 100644 src/maxdiffusion/models/quantizations.py create mode 100644 src/maxdiffusion/unet_quantization.py create mode 100644 src/maxdiffusion/unet_quantization_utils.py diff --git a/requirements.txt b/requirements.txt index 0cf1fbae..b76d4787 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,8 @@ google-cloud-storage absl-py transformers>=4.25.1 datasets -flax +flax>=0.8.1 +aqtp optax torch torchvision diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml index d68016e1..d0c23a52 100644 --- a/src/maxdiffusion/configs/base_xl.yml +++ b/src/maxdiffusion/configs/base_xl.yml @@ -154,4 +154,9 @@ lightning_repo: "" # Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. lightning_ckpt: "" -enable_mllog: False \ No newline at end of file +enable_mllog: False + +quantization: 'int8' +# Shard the range finding operation for quantization. By default this is set to number of slices. +quantization_local_shard_count: -1 +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 51f9fe8d..45844fde 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -349,10 +349,12 @@ def setup_initial_state(model, tx, config, mesh, model_params, unboxed_abstract_ def get_states(mesh, tx, rng, config, pipeline, unet_params, vae_params, training=True): # Needed to initialize weights on multi-host with addressable devices. + quant_enabled = config.quantization is not None if config.train_new_unet: unet_variables = jax.jit(pipeline.unet.init_weights, static_argnames=["eval_only"])(rng, eval_only=False) else: - unet_variables = pipeline.unet.init_weights(rng, eval_only=True) + #unet_variables = jax.jit(pipeline.unet.init_weights, static_argnames=["quantization_enabled"])(rng, quantization_enabled=quant_enabled) + unet_variables = pipeline.unet.init_weights(rng, eval_only=True, quant_enabled=quant_enabled) unboxed_abstract_state, state_mesh_annotations = get_abstract_state(pipeline.unet, tx, config, mesh, unet_variables, training=training) if config.train_new_unet: diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index b53d6f9b..4e5f368a 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -23,6 +23,8 @@ from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel from ..import common_types, max_logging +from . 
import quantizations + Array = common_types.Array Mesh = common_types.Mesh @@ -36,6 +38,11 @@ HEAD = common_types.HEAD D_KV = common_types.D_KV +Quant = quantizations.AqtQuantization + +def _maybe_aqt_einsum(quant: Quant): + return jnp.einsum if quant is None else quant.einsum() + class AttentionOp(nn.Module): mesh: Mesh attention_kernel: str @@ -48,6 +55,7 @@ class AttentionOp(nn.Module): flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV) flash_min_seq_length: int = 4096 flash_block_sizes: BlockSizes = None + quant: Quant = None dtype: DType = jnp.float32 def check_attention_inputs( @@ -385,6 +393,9 @@ class FlaxAttention(nn.Module): jax mesh is required if attention is set to flash. dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` + quant (`AqtQuantization`, *optional*, defaults to None) + Configures AQT quantization github.com/google/aqt. + """ query_dim: int @@ -402,6 +413,7 @@ class FlaxAttention(nn.Module): key_axis_names: AxisNames = (BATCH, LENGTH, HEAD) value_axis_names: AxisNames = (BATCH, LENGTH, HEAD) out_axis_names: AxisNames = (BATCH, LENGTH, HEAD) + quant: Quant = None def setup(self): @@ -421,7 +433,8 @@ def setup(self): use_memory_efficient_attention=self.use_memory_efficient_attention, split_head_dim=self.split_head_dim, flash_block_sizes=self.flash_block_sizes, - dtype=self.dtype + dtype=self.dtype, + quant=self.quant ) qkv_init_kernel = nn.with_logical_partitioning( @@ -429,9 +442,14 @@ def setup(self): ("embed","heads") ) + dot_general_cls = None + if self.quant: + dot_general_cls = self.quant.dot_general_cls() + self.query = nn.Dense( inner_dim, kernel_init=qkv_init_kernel, + dot_general_cls=dot_general_cls, use_bias=False, dtype=self.dtype, name="to_q" @@ -440,6 +458,7 @@ def setup(self): self.key = nn.Dense( inner_dim, kernel_init=qkv_init_kernel, + dot_general_cls=dot_general_cls, use_bias=False, dtype=self.dtype, name="to_k" @@ -448,6 +467,7 @@ def setup(self): self.value = nn.Dense( inner_dim, kernel_init=qkv_init_kernel, + dot_general_cls=dot_general_cls, use_bias=False, dtype=self.dtype, name="to_v") @@ -458,6 +478,7 @@ def setup(self): nn.initializers.lecun_normal(), ("heads","embed") ), + dot_general_cls=dot_general_cls, dtype=self.dtype, name="to_out_0") self.dropout_layer = nn.Dropout(rate=self.dropout) @@ -520,6 +541,8 @@ class FlaxBasicTransformerBlock(nn.Module): Overrides default block sizes for flash attention. mesh (`jax.sharding.mesh`, *optional*, defaults to `None`): jax mesh is required if attention is set to flash. + quant (`AqtQuantization`, *optional*, defaults to None) + Configures AQT quantization github.com/google/aqt. """ dim: int n_heads: int @@ -533,6 +556,7 @@ class FlaxBasicTransformerBlock(nn.Module): flash_min_seq_length: int = 4096 flash_block_sizes: BlockSizes = None mesh: jax.sharding.Mesh = None + quant: Quant = None def setup(self): # self attention (or cross_attention if only_cross_attention is True) @@ -548,6 +572,7 @@ def setup(self): flash_block_sizes=self.flash_block_sizes, mesh=self.mesh, dtype=self.dtype, + quant=self.quant, ) # cross attention self.attn2 = FlaxAttention( @@ -562,6 +587,7 @@ def setup(self): flash_block_sizes=self.flash_block_sizes, mesh=self.mesh, dtype=self.dtype, + quant=self.quant, ) self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) @@ -625,6 +651,8 @@ class FlaxTransformer2DModel(nn.Module): Overrides default block sizes for flash attention. 
mesh (`jax.sharding.mesh`, *optional*, defaults to `None`): jax mesh is required if attention is set to flash. + quant (`AqtQuantization`, *optional*, defaults to None) + Configures AQT quantization github.com/google/aqt. """ in_channels: int n_heads: int @@ -641,6 +669,7 @@ class FlaxTransformer2DModel(nn.Module): flash_block_sizes: BlockSizes = None mesh: jax.sharding.Mesh = None norm_num_groups: int = 32 + quant: Quant = None def setup(self): self.norm = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-5) @@ -679,7 +708,8 @@ def setup(self): attention_kernel=self.attention_kernel, flash_min_seq_length=self.flash_min_seq_length, flash_block_sizes=self.flash_block_sizes, - mesh=self.mesh + mesh=self.mesh, + quant=self.quant ) for _ in range(self.depth) ] diff --git a/src/maxdiffusion/models/quantizations.py b/src/maxdiffusion/models/quantizations.py new file mode 100644 index 00000000..f6792b3b --- /dev/null +++ b/src/maxdiffusion/models/quantizations.py @@ -0,0 +1,131 @@ +""" + Copyright 2024 Google LLC + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +import functools + +from aqt.jax.v2 import config as aqt_config +from aqt.jax.v2.flax import aqt_flax +from ..common_types import Config +from dataclasses import dataclass +import jax.numpy as jnp + + +@dataclass +class AqtQuantization: + """ Configures AQT quantization github.com/google/aqt. """ + quant_dg: aqt_config.DotGeneral + lhs_quant_mode: aqt_flax.QuantMode + rhs_quant_mode: aqt_flax.QuantMode + + def dot_general_cls(self): + """ Returns dot_general configured with aqt params. 
""" + aqt_dg_cls = functools.partial( + aqt_flax.AqtDotGeneral, + self.quant_dg, + lhs_quant_mode=self.lhs_quant_mode, + rhs_quant_mode=self.rhs_quant_mode, + ) + return aqt_dg_cls + + def einsum(self): + """ Returns einsum configured with aqt params """ + aqt_einsum = functools.partial(aqt_flax.AqtEinsum( + cfg=self.quant_dg, + lhs_quant_mode=self.lhs_quant_mode, + rhs_quant_mode=self.rhs_quant_mode, + ) + ) + return aqt_einsum + +def _get_quant_config(config): + if not config.quantization or config.quantization == '': + return None + elif config.quantization == "int8": + if config.quantization_local_shard_count == 0: + drhs_bits = None + drhs_accumulator_dtype = None + drhs_local_aqt=None + else: + drhs_bits = 8 + drhs_accumulator_dtype = jnp.int32 + drhs_local_aqt = aqt_config.LocalAqt(config.quantization_local_shard_count) + return aqt_config.config_v3( + fwd_bits=8, + dlhs_bits=8, + drhs_bits=drhs_bits, + rng_type='jax.uniform', + dlhs_local_aqt=None, + drhs_local_aqt=drhs_local_aqt, + fwd_accumulator_dtype=jnp.int32, + dlhs_accumulator_dtype=jnp.int32, + drhs_accumulator_dtype=drhs_accumulator_dtype, + ) + else: + raise ValueError(f'Invalid value configured for quantization {config.quantization}.') + +def in_convert_mode(quant): + return quant and (quant.quant_mode == aqt_flax.QuantMode.CONVERT) + +def in_serve_mode(quant): + return quant and (quant.quant_mode == aqt_flax.QuantMode.SERVE) + +def get_quant_mode(quant_mode_str: str = 'train'): + """ Set quant mode.""" + if quant_mode_str == 'train': + return aqt_flax.QuantMode.TRAIN + elif quant_mode_str == 'serve': + return aqt_flax.QuantMode.SERVE + elif quant_mode_str == 'convert': + return aqt_flax.QuantMode.CONVERT + else: + raise ValueError(f'Invalid quantization mode {quant_mode_str}.') + return None + +def configure_quantization(config: Config, lhs_quant_mode=aqt_flax.QuantMode.TRAIN, rhs_quant_mode=aqt_flax.QuantMode.TRAIN): + """ Configure quantization based on user config and quant mode.""" + quant_cfg = _get_quant_config(config) + if quant_cfg: + return AqtQuantization(quant_dg=quant_cfg, lhs_quant_mode=lhs_quant_mode, rhs_quant_mode=rhs_quant_mode) + return None + +# @dataclass +# class AqtQuantization: +# """ Configures AQT quantization github.com/google/aqt. """ +# quant_dg: aqt_config.DotGeneral +# quant_mode: aqt_flax.QuantMode = aqt_flax.QuantMode.TRAIN + + + + +# def dot_general_cls_aqt(self, aqt_cfg, lhs_quant_mode, rhs_quant_mode): +# """ Returns dot_general configured with aqt params. """ +# aqt_dg_cls = functools.partial( +# aqt_flax.AqtDotGeneral, +# aqt_cfg, +# lhs_quant_mode=lhs_quant_mode, +# rhs_quant_mode=rhs_quant_mode, +# lhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION, +# rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE, +# ) +# return aqt_dg_cls + +# def einsum_aqt(self, aqt_cfg, lhs_quant_mode, rhs_quant_mode): +# return functools.partial( +# aqt_flax.AqtEinsum, +# aqt_cfg, +# lhs_quant_mode=lhs_quant_mode, +# rhs_quant_mode=rhs_quant_mode, +# lhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION, +# rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE, +# ) + \ No newline at end of file diff --git a/src/maxdiffusion/models/unet_2d_blocks_flax.py b/src/maxdiffusion/models/unet_2d_blocks_flax.py index d9ab86c2..3aa1e5cc 100644 --- a/src/maxdiffusion/models/unet_2d_blocks_flax.py +++ b/src/maxdiffusion/models/unet_2d_blocks_flax.py @@ -18,6 +18,10 @@ from .attention_flax import FlaxTransformer2DModel from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D +from . 
import quantizations + +Quant = quantizations.AqtQuantization + from ..common_types import BlockSizes class FlaxCrossAttnDownBlock2D(nn.Module): @@ -53,6 +57,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module): jax mesh is required if attention is set to flash. dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` + quant (`AqtQuantization`, *optional*, defaults to None) + Configures AQT quantization github.com/google/aqt. """ in_channels: int out_channels: int @@ -69,6 +75,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module): flash_block_sizes: BlockSizes = None mesh: jax.sharding.Mesh = None dtype: jnp.dtype = jnp.float32 + quant: Quant = None transformer_layers_per_block: int = 1 norm_num_groups: int = 32 @@ -102,7 +109,8 @@ def setup(self): flash_block_sizes=self.flash_block_sizes, mesh=self.mesh, dtype=self.dtype, - norm_num_groups=self.norm_num_groups + norm_num_groups=self.norm_num_groups, + quant=self.quant, ) attentions.append(attn_block) @@ -219,6 +227,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module): jax mesh is required if attention is set to flash. dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` + quant (`AqtQuantization`, *optional*, defaults to None) + Configures AQT quantization github.com/google/aqt. """ in_channels: int out_channels: int @@ -238,6 +248,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module): dtype: jnp.dtype = jnp.float32 transformer_layers_per_block: int = 1 norm_num_groups: int = 32 + quant: Quant = None + def setup(self): resnets = [] @@ -270,7 +282,8 @@ def setup(self): flash_block_sizes=self.flash_block_sizes, mesh=self.mesh, dtype=self.dtype, - norm_num_groups=self.norm_num_groups + norm_num_groups=self.norm_num_groups, + quant=self.quant, ) attentions.append(attn_block) @@ -389,6 +402,8 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): jax mesh is required if attention is set to flash. dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` + quant (`AqtQuantization`, *optional*, defaults to None) + Configures AQT quantization github.com/google/aqt. """ in_channels: int dropout: float = 0.0 @@ -404,6 +419,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): dtype: jnp.dtype = jnp.float32 transformer_layers_per_block: int = 1 norm_num_groups: int = 32 + quant=None def setup(self): # there is always at least one resnet @@ -433,7 +449,8 @@ def setup(self): flash_block_sizes=self.flash_block_sizes, mesh=self.mesh, dtype=self.dtype, - norm_num_groups=self.norm_num_groups + norm_num_groups=self.norm_num_groups, + quant=self.quant, ) attentions.append(attn_block) diff --git a/src/maxdiffusion/models/unet_2d_condition_flax.py b/src/maxdiffusion/models/unet_2d_condition_flax.py index 0c6e7f73..7c28732b 100644 --- a/src/maxdiffusion/models/unet_2d_condition_flax.py +++ b/src/maxdiffusion/models/unet_2d_condition_flax.py @@ -31,8 +31,10 @@ FlaxUpBlock2D, ) +from . import quantizations from ..common_types import BlockSizes +Quant = quantizations.AqtQuantization @flax.struct.dataclass class FlaxUNet2DConditionOutput(BaseOutput): @@ -105,6 +107,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): Overrides default block sizes for flash attention. mesh (`jax.sharding.mesh`, *optional*, defaults to `None`): jax mesh is required if attention is set to flash. + quant (`AqtQuantization`, *optional*, defaults to None) + Configures AQT quantization github.com/google/aqt. 
""" sample_size: int = 32 @@ -140,8 +144,9 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): addition_embed_type_num_heads: int = 64 projection_class_embeddings_input_dim: Optional[int] = None norm_num_groups: int = 32 + quant: Quant = None - def init_weights(self, rng: jax.Array, eval_only: bool = False) -> FrozenDict: + def init_weights(self, rng: jax.Array, eval_only: bool = False, quantization_enabled: bool = False) -> FrozenDict: # init input tensors no_devices = jax.device_count() sample_shape = (no_devices, self.in_channels, self.sample_size, self.sample_size) @@ -151,6 +156,8 @@ def init_weights(self, rng: jax.Array, eval_only: bool = False) -> FrozenDict: params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} + if quantization_enabled: + rngs["aqt"] = params_rng added_cond_kwargs = None if self.addition_embed_type == "text_time": @@ -260,6 +267,7 @@ def setup(self): flash_block_sizes=self.flash_block_sizes, mesh=self.mesh, dtype=self.dtype, + quant=self.quant, ) else: down_block = FlaxDownBlock2D( @@ -288,6 +296,7 @@ def setup(self): flash_block_sizes=self.flash_block_sizes, mesh=self.mesh, dtype=self.dtype, + quant=self.quant, ) # up @@ -323,6 +332,7 @@ def setup(self): flash_block_sizes=self.flash_block_sizes, mesh=self.mesh, dtype=self.dtype, + quant=self.quant, ) else: up_block = FlaxUpBlock2D( diff --git a/src/maxdiffusion/pipelines/pipeline_flax_utils.py b/src/maxdiffusion/pipelines/pipeline_flax_utils.py index b28450fa..5a2fd214 100644 --- a/src/maxdiffusion/pipelines/pipeline_flax_utils.py +++ b/src/maxdiffusion/pipelines/pipeline_flax_utils.py @@ -329,6 +329,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P mesh = kwargs.pop("mesh", None) dtype = kwargs.pop("dtype", None) norm_num_groups = kwargs.pop("norm_num_groups", 32) + quant = kwargs.pop("quant", None) # 1. 
Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained @@ -513,6 +514,7 @@ def load_module(name, value): mesh=mesh, norm_num_groups=norm_num_groups, dtype=dtype, + quant=quant, ) params[name] = loaded_params elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel): diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index eed27f76..8cedb99c 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -111,13 +111,32 @@ def user_init(raw_keys): raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"]) raw_keys["data_sharding"] = _lists_to_tuples(raw_keys["data_sharding"]) - + raw_keys["num_slices"] = get_num_slices(raw_keys) + raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) + if raw_keys["learning_rate_schedule_steps"]==-1: raw_keys["learning_rate_schedule_steps"] = raw_keys["max_train_steps"] if "gs://" in raw_keys["pretrained_model_name_or_path"]: raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], "/tmp") +def get_num_slices(raw_keys): + if int(raw_keys['compile_topology_num_slices']) > 0: + return raw_keys['compile_topology_num_slices'] + else: + devices = jax.devices() + try: + return 1 + max([d.slice_index for d in devices]) + except: + return 1 + +def get_quantization_local_shard_count(raw_keys): + if raw_keys['quantization_local_shard_count'] == -1: + return raw_keys['num_slices'] + else: + return raw_keys['quantization_local_shard_count'] + + def get_num_target_devices(raw_keys): return len(jax.devices()) diff --git a/src/maxdiffusion/unet_quantization.py b/src/maxdiffusion/unet_quantization.py new file mode 100644 index 00000000..d0821ebd --- /dev/null +++ b/src/maxdiffusion/unet_quantization.py @@ -0,0 +1,295 @@ +""" + Copyright 2024 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ """ + +import os +import functools +from absl import app +from typing import Sequence +import time + +import numpy as np +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +from jax.sharding import PartitionSpec as P +from jax.experimental.compilation_cache import compilation_cache as cc +from flax.linen import partitioning as nn_partitioning +from jax.sharding import PositionalSharding + +from maxdiffusion import ( + FlaxStableDiffusionXLPipeline, + FlaxEulerDiscreteScheduler, + FlaxDDPMScheduler +) + + +from maxdiffusion import pyconfig +from maxdiffusion.image_processor import VaeImageProcessor +from maxdiffusion.max_utils import ( + create_device_mesh, + get_dtype, + get_states, + activate_profiler, + deactivate_profiler, + device_put_replicated, + get_flash_block_sizes, +) +from maxdiffusion.maxdiffusion_utils import ( + load_sdxllightning_unet, + get_add_time_ids, + rescale_noise_cfg +) +from maxdiffusion.models import quantizations + +cc.set_cache_dir(os.path.expanduser("~/jax_cache")) + +def loop_body(step, args, model, pipeline, added_cond_kwargs, prompt_embeds, guidance_scale, guidance_rescale): + latents, scheduler_state, state = args + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = pipeline.scheduler.scale_model_input(scheduler_state, latents_input, t) + noise_pred = model.apply( + {"params" : state.params}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=added_cond_kwargs + ).sample + + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_prediction_text, guidance_rescale=guidance_rescale) + + + latents, scheduler_state = pipeline.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + + return latents, scheduler_state, state + +def get_embeddings(prompt_ids, pipeline, params): + te_1_inputs = prompt_ids[:, 0, :] + te_2_inputs = prompt_ids[:, 1, :] + + prompt_embeds = pipeline.text_encoder( + te_1_inputs, params=params["text_encoder"], output_hidden_states=True + ) + prompt_embeds = prompt_embeds["hidden_states"][-2] + prompt_embeds_2_out = pipeline.text_encoder_2( + te_2_inputs, params=params["text_encoder_2"], output_hidden_states=True + ) + prompt_embeds_2 = prompt_embeds_2_out["hidden_states"][-2] + text_embeds = prompt_embeds_2_out["text_embeds"] + prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1) + return prompt_embeds, text_embeds + +def tokenize(prompt, pipeline): + inputs = [] + for _tokenizer in [pipeline.tokenizer, pipeline.tokenizer_2]: + text_inputs = _tokenizer( + prompt, + padding="max_length", + max_length=_tokenizer.model_max_length, + truncation=True, + return_tensors="np" + ) + inputs.append(text_inputs.input_ids) + inputs = jnp.stack(inputs,axis=1) + return inputs + +def run(config): + rng = jax.random.PRNGKey(config.seed) + + # Setup Mesh + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + batch_size = config.per_device_batch_size * jax.device_count() + + weight_dtype = get_dtype(config) + flash_block_sizes = get_flash_block_sizes(config) + + quant = quantizations.configure_quantization(config=config, aqt_flax.QuantMode.TRAIN, aqt_flax.QuantMode.CONVERT) + pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=weight_dtype, + split_head_dim=config.split_head_dim, + norm_num_groups=config.norm_num_groups, + attention_kernel=config.attention, + flash_block_sizes=flash_block_sizes, + mesh=mesh, + quant=quant, + ) + + # if this checkpoint was trained with maxdiffusion + # the training scheduler was saved with it, switch it + # to a Euler scheduler + if isinstance(pipeline.scheduler, FlaxDDPMScheduler): + noise_scheduler, noise_scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, subfolder="scheduler", dtype=jnp.float32 + ) + pipeline.scheduler = noise_scheduler + params["scheduler"] = noise_scheduler_state + + if config.lightning_repo: + pipeline, params = load_sdxllightning_unet(config, pipeline, params) + + scheduler_state = params.pop("scheduler") + old_params = params + params = jax.tree_util.tree_map(lambda x: x.astype(weight_dtype), old_params) + params["scheduler"] = scheduler_state + + data_sharding = jax.sharding.NamedSharding(mesh,P(*config.data_sharding)) + + sharding = PositionalSharding(devices_array).replicate() + partial_device_put_replicated = functools.partial(device_put_replicated, sharding=sharding) + params["text_encoder"] = jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder"]) + params["text_encoder_2"] = jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder_2"]) + + unet_state, unet_state_mesh_shardings, vae_state, vae_state_mesh_shardings = get_states(mesh, None, rng, config, pipeline, params["unet"], params["vae"], training=False) + del params["vae"] + del params["unet"] + + def get_unet_inputs(rng, config, 
batch_size, pipeline, params): + vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) + prompt_ids = [config.prompt] * batch_size + prompt_ids = tokenize(prompt_ids, pipeline) + negative_prompt_ids = [config.negative_prompt] * batch_size + negative_prompt_ids = tokenize(negative_prompt_ids, pipeline) + guidance_scale = config.guidance_scale + guidance_rescale = config.guidance_rescale + num_inference_steps = config.num_inference_steps + height = config.resolution + width = config.resolution + prompt_embeds, pooled_embeds = get_embeddings(prompt_ids, pipeline, params) + batch_size = prompt_embeds.shape[0] + negative_prompt_embeds, negative_pooled_embeds = get_embeddings(negative_prompt_ids, pipeline, params) + add_time_ids = get_add_time_ids( + (height, width), (0, 0), (height, width), prompt_embeds.shape[0], dtype=prompt_embeds.dtype + ) + + prompt_embeds = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0) + add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0) + add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0) + # Ensure model output will be `float32` before going into the scheduler + guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32) + guidance_rescale = jnp.array([guidance_rescale], dtype=jnp.float32) + + latents_shape = ( + batch_size, + pipeline.unet.config.in_channels, + height // vae_scale_factor, + width // vae_scale_factor, + ) + + latents = jax.random.normal(rng, shape=latents_shape, dtype=jnp.float32) + + scheduler_state = pipeline.scheduler.set_timesteps( + params["scheduler"], + num_inference_steps=num_inference_steps, + shape=latents.shape + ) + + latents = latents * scheduler_state.init_noise_sigma + + added_cond_kwargs = {"text_embeds" : add_text_embeds, "time_ids" : add_time_ids} + latents = jax.device_put(latents, data_sharding) + prompt_embeds = jax.device_put(prompt_embeds, data_sharding) + guidance_scale = jax.device_put(guidance_scale, PositionalSharding(devices_array).replicate()) + added_cond_kwargs['text_embeds'] = jax.device_put(added_cond_kwargs['text_embeds'], data_sharding) + added_cond_kwargs['time_ids'] = jax.device_put(added_cond_kwargs['time_ids'], data_sharding) + + return latents, prompt_embeds, added_cond_kwargs, guidance_scale, guidance_rescale, scheduler_state + + def vae_decode(latents, state, pipeline): + latents = 1 / pipeline.vae.config.scaling_factor * latents + image = pipeline.vae.apply( + {"params" : state.params}, + latents, + method=pipeline.vae.decode + ).sample + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeline): + + (latents, + prompt_embeds, + added_cond_kwargs, + guidance_scale, + guidance_rescale, + scheduler_state) = get_unet_inputs(rng, config, batch_size, pipeline, params) + + loop_body_p = functools.partial(loop_body, model=pipeline.unet, + pipeline=pipeline, + added_cond_kwargs=added_cond_kwargs, + prompt_embeds=prompt_embeds, + guidance_scale=guidance_scale, + guidance_rescale=guidance_rescale) + vae_decode_p = functools.partial(vae_decode, pipeline=pipeline) + + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + latents, _, _ = jax.lax.fori_loop(0, config.num_inference_steps, + loop_body_p, (latents, scheduler_state, unet_state)) + image = vae_decode_p(latents, vae_state) + return image + + p_run_inference = jax.jit( + functools.partial(run_inference, rng=rng, config=config, batch_size=batch_size, 
pipeline=pipeline), + in_shardings=(unet_state_mesh_shardings, vae_state_mesh_shardings, None), + out_shardings=None + ) + + s = time.time() + p_run_inference(unet_state, vae_state, params).block_until_ready() + print("compile time: ", (time.time() - s)) + s = time.time() + images = p_run_inference(unet_state, vae_state, params).block_until_ready() + images.block_until_ready() + print("inference time: ",(time.time() - s)) + s = time.time() + images = p_run_inference(unet_state, vae_state, params).block_until_ready() #run_inference(unet_state, vae_state, latents, scheduler_state) + images.block_until_ready() + print("inference time: ",(time.time() - s)) + s = time.time() + images = p_run_inference(unet_state, vae_state, params).block_until_ready() # run_inference(unet_state, vae_state, latents, scheduler_state) + images.block_until_ready() + print("inference time: ",(time.time() - s)) + s = time.time() + activate_profiler(config) + images = p_run_inference(unet_state, vae_state, params).block_until_ready() + deactivate_profiler(config) + images.block_until_ready() + print("inference time: ",(time.time() - s)) + images = jax.experimental.multihost_utils.process_allgather(images) + numpy_images = np.array(images) + images = VaeImageProcessor.numpy_to_pil(numpy_images) + for i, image in enumerate(images): + image.save(f"image_sdxl_{i}.png") + + return images + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxdiffusion/unet_quantization_utils.py b/src/maxdiffusion/unet_quantization_utils.py new file mode 100644 index 00000000..6589ddbb --- /dev/null +++ b/src/maxdiffusion/unet_quantization_utils.py @@ -0,0 +1,7 @@ +# initialize in serve mode - No calibration +# _get_quantized_vars - vars are quantized + +call loop body 1 step and save the vars. + +That is Unet quantized vars. +apply this unet vars to unet model and the call serve on quantized. 
\ No newline at end of file From dde5d06edccec8c98751244c5c8d6bc2a4779d25 Mon Sep 17 00:00:00 2001 From: mailvijayasingh Date: Mon, 3 Jun 2024 20:22:11 +0000 Subject: [PATCH 2/7] AQT wip --- src/maxdiffusion/unet_quantization.py | 47 ++++++++++++++++++++- src/maxdiffusion/unet_quantization_utils.py | 7 --- 2 files changed, 46 insertions(+), 8 deletions(-) delete mode 100644 src/maxdiffusion/unet_quantization_utils.py diff --git a/src/maxdiffusion/unet_quantization.py b/src/maxdiffusion/unet_quantization.py index d0821ebd..52ad6634 100644 --- a/src/maxdiffusion/unet_quantization.py +++ b/src/maxdiffusion/unet_quantization.py @@ -28,6 +28,8 @@ from jax.experimental.compilation_cache import compilation_cache as cc from flax.linen import partitioning as nn_partitioning from jax.sharding import PositionalSharding +from aqt.jax.v2.flax import aqt_flax + from maxdiffusion import ( FlaxStableDiffusionXLPipeline, @@ -46,6 +48,8 @@ deactivate_profiler, device_put_replicated, get_flash_block_sizes, + get_abstract_state, + setup_initial_state ) from maxdiffusion.maxdiffusion_utils import ( load_sdxllightning_unet, @@ -83,6 +87,24 @@ def loop_body(step, args, model, pipeline, added_cond_kwargs, prompt_embeds, gui return latents, scheduler_state, state +def loop_body_for_quantization(step, args, model, pipeline, added_cond_kwargs, prompt_embeds, guidance_scale, guidance_rescale): + latents, scheduler_state, state = args + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = pipeline.scheduler.scale_model_input(scheduler_state, latents_input, t) + noise_pred, quantized_unet_vars = model.apply( + {"params" : state.params}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=added_cond_kwargs + ) + return quantized_unet_vars + + def get_embeddings(prompt_ids, pipeline, params): te_1_inputs = prompt_ids[:, 0, :] te_2_inputs = prompt_ids[:, 1, :] @@ -125,7 +147,7 @@ def run(config): weight_dtype = get_dtype(config) flash_block_sizes = get_flash_block_sizes(config) - quant = quantizations.configure_quantization(config=config, aqt_flax.QuantMode.TRAIN, aqt_flax.QuantMode.CONVERT) + quant = quantizations.configure_quantization(config=config, lhs_quant_mode=aqt_flax.QuantMode.TRAIN, rhs_quant_mode=aqt_flax.QuantMode.CONVERT) pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( config.pretrained_model_name_or_path, revision=config.revision, @@ -238,6 +260,27 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli guidance_rescale, scheduler_state) = get_unet_inputs(rng, config, batch_size, pipeline, params) + loop_body_quant_p = functools.partial(loop_body, model=pipeline.unet, + pipeline=pipeline, + added_cond_kwargs=added_cond_kwargs, + prompt_embeds=prompt_embeds, + guidance_scale=guidance_scale, + guidance_rescale=guidance_rescale) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + quantized_unet_vars = jax.lax.fori_loop(0, 1, + loop_body_p, (latents, scheduler_state, unet_state)) + unboxed_abstract_state, state_mesh_annotations = get_abstract_state(pipeline.unet, None, config, mesh, quantized_unet_vars, training=False) + unet_state, unet_state_mesh_shardings = setup_initial_state( + pipeline.unet, + None, + config, + mesh, + quantized_unet_vars, + unboxed_abstract_state, + state_mesh_annotations, + training=False) + + loop_body_p = 
functools.partial(loop_body, model=pipeline.unet, pipeline=pipeline, added_cond_kwargs=added_cond_kwargs, @@ -246,6 +289,8 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli guidance_rescale=guidance_rescale) vae_decode_p = functools.partial(vae_decode, pipeline=pipeline) + + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): latents, _, _ = jax.lax.fori_loop(0, config.num_inference_steps, loop_body_p, (latents, scheduler_state, unet_state)) diff --git a/src/maxdiffusion/unet_quantization_utils.py b/src/maxdiffusion/unet_quantization_utils.py deleted file mode 100644 index 6589ddbb..00000000 --- a/src/maxdiffusion/unet_quantization_utils.py +++ /dev/null @@ -1,7 +0,0 @@ -# initialize in serve mode - No calibration -# _get_quantized_vars - vars are quantized - -call loop body 1 step and save the vars. - -That is Unet quantized vars. -apply this unet vars to unet model and the call serve on quantized. \ No newline at end of file From 06166022bcc43fff42cec485cc41c7b838e55b06 Mon Sep 17 00:00:00 2001 From: mailvijayasingh Date: Mon, 3 Jun 2024 21:18:39 +0000 Subject: [PATCH 3/7] wip AQT --- src/maxdiffusion/models/quantizations.py | 3 ++- src/maxdiffusion/unet_quantization.py | 10 +++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/maxdiffusion/models/quantizations.py b/src/maxdiffusion/models/quantizations.py index f6792b3b..eb504672 100644 --- a/src/maxdiffusion/models/quantizations.py +++ b/src/maxdiffusion/models/quantizations.py @@ -58,8 +58,9 @@ def _get_quant_config(config): else: drhs_bits = 8 drhs_accumulator_dtype = jnp.int32 + print(config.quantization_local_shard_count) # -1 drhs_local_aqt = aqt_config.LocalAqt(config.quantization_local_shard_count) - return aqt_config.config_v3( + return aqt_config.config_v4( fwd_bits=8, dlhs_bits=8, drhs_bits=drhs_bits, diff --git a/src/maxdiffusion/unet_quantization.py b/src/maxdiffusion/unet_quantization.py index 52ad6634..f1f4bfc0 100644 --- a/src/maxdiffusion/unet_quantization.py +++ b/src/maxdiffusion/unet_quantization.py @@ -51,12 +51,12 @@ get_abstract_state, setup_initial_state ) -from maxdiffusion.maxdiffusion_utils import ( +from .maxdiffusion_utils import ( load_sdxllightning_unet, get_add_time_ids, rescale_noise_cfg ) -from maxdiffusion.models import quantizations +from .models import quantizations cc.set_cache_dir(os.path.expanduser("~/jax_cache")) @@ -260,7 +260,7 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli guidance_rescale, scheduler_state) = get_unet_inputs(rng, config, batch_size, pipeline, params) - loop_body_quant_p = functools.partial(loop_body, model=pipeline.unet, + loop_body_quant_p = functools.partial(loop_body_for_quantization, model=pipeline.unet, pipeline=pipeline, added_cond_kwargs=added_cond_kwargs, prompt_embeds=prompt_embeds, @@ -268,7 +268,7 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli guidance_rescale=guidance_rescale) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): quantized_unet_vars = jax.lax.fori_loop(0, 1, - loop_body_p, (latents, scheduler_state, unet_state)) + loop_body_quant_p, (latents, scheduler_state, unet_state)) unboxed_abstract_state, state_mesh_annotations = get_abstract_state(pipeline.unet, None, config, mesh, quantized_unet_vars, training=False) unet_state, unet_state_mesh_shardings = setup_initial_state( pipeline.unet, @@ -289,7 +289,7 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli 
guidance_rescale=guidance_rescale) vae_decode_p = functools.partial(vae_decode, pipeline=pipeline) - + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): latents, _, _ = jax.lax.fori_loop(0, config.num_inference_steps, From c79ce9815acff6504f98b0bcbf05b60ce7d03e82 Mon Sep 17 00:00:00 2001 From: mailvijayasingh Date: Tue, 4 Jun 2024 17:24:24 +0000 Subject: [PATCH 4/7] aqt wip --- src/maxdiffusion/models/quantizations.py | 3 ++- src/maxdiffusion/unet_quantization.py | 12 +++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/maxdiffusion/models/quantizations.py b/src/maxdiffusion/models/quantizations.py index eb504672..acb9c90c 100644 --- a/src/maxdiffusion/models/quantizations.py +++ b/src/maxdiffusion/models/quantizations.py @@ -14,6 +14,7 @@ import functools from aqt.jax.v2 import config as aqt_config + from aqt.jax.v2.flax import aqt_flax from ..common_types import Config from dataclasses import dataclass @@ -59,7 +60,7 @@ def _get_quant_config(config): drhs_bits = 8 drhs_accumulator_dtype = jnp.int32 print(config.quantization_local_shard_count) # -1 - drhs_local_aqt = aqt_config.LocalAqt(config.quantization_local_shard_count) + drhs_local_aqt = aqt_config.LocalAqt(contraction_axis_shard_count=config.quantization_local_shard_count) return aqt_config.config_v4( fwd_bits=8, dlhs_bits=8, diff --git a/src/maxdiffusion/unet_quantization.py b/src/maxdiffusion/unet_quantization.py index f1f4bfc0..9c6e58a1 100644 --- a/src/maxdiffusion/unet_quantization.py +++ b/src/maxdiffusion/unet_quantization.py @@ -88,7 +88,7 @@ def loop_body(step, args, model, pipeline, added_cond_kwargs, prompt_embeds, gui return latents, scheduler_state, state def loop_body_for_quantization(step, args, model, pipeline, added_cond_kwargs, prompt_embeds, guidance_scale, guidance_rescale): - latents, scheduler_state, state = args + latents, scheduler_state, state, rng = args latents_input = jnp.concatenate([latents] * 2) t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] @@ -96,11 +96,13 @@ def loop_body_for_quantization(step, args, model, pipeline, added_cond_kwargs, p latents_input = pipeline.scheduler.scale_model_input(scheduler_state, latents_input, t) noise_pred, quantized_unet_vars = model.apply( - {"params" : state.params}, + state.params | {"aqt" : {}}, jnp.array(latents_input), jnp.array(timestep, dtype=jnp.int32), encoder_hidden_states=prompt_embeds, - added_cond_kwargs=added_cond_kwargs + added_cond_kwargs=added_cond_kwargs, + rngs={"params": rng}, + mutable=True, ) return quantized_unet_vars @@ -267,8 +269,8 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli guidance_scale=guidance_scale, guidance_rescale=guidance_rescale) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - quantized_unet_vars = jax.lax.fori_loop(0, 1, - loop_body_quant_p, (latents, scheduler_state, unet_state)) + quantized_unet_vars = loop_body_quant_p, (latents, scheduler_state, unet_state, rng) + unboxed_abstract_state, state_mesh_annotations = get_abstract_state(pipeline.unet, None, config, mesh, quantized_unet_vars, training=False) unet_state, unet_state_mesh_shardings = setup_initial_state( pipeline.unet, From e9af9db72549bc2ed581444cf3b13eecd197def4 Mon Sep 17 00:00:00 2001 From: mailvijayasingh Date: Tue, 4 Jun 2024 21:02:28 +0000 Subject: [PATCH 5/7] wip AQT --- src/maxdiffusion/unet_quantization.py | 81 +++++++++++++++++++++++---- 1 file changed, 70 insertions(+), 11 deletions(-) diff --git a/src/maxdiffusion/unet_quantization.py 
b/src/maxdiffusion/unet_quantization.py index 9c6e58a1..c0114f4a 100644 --- a/src/maxdiffusion/unet_quantization.py +++ b/src/maxdiffusion/unet_quantization.py @@ -87,11 +87,11 @@ def loop_body(step, args, model, pipeline, added_cond_kwargs, prompt_embeds, gui return latents, scheduler_state, state -def loop_body_for_quantization(step, args, model, pipeline, added_cond_kwargs, prompt_embeds, guidance_scale, guidance_rescale): - latents, scheduler_state, state, rng = args +def loop_body_for_quantization(latents, scheduler_state, state, rng, model, pipeline, added_cond_kwargs, prompt_embeds, guidance_scale, guidance_rescale): + # latents, scheduler_state, state, rng = args latents_input = jnp.concatenate([latents] * 2) - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[0] timestep = jnp.broadcast_to(t, latents_input.shape[0]) latents_input = pipeline.scheduler.scale_model_input(scheduler_state, latents_input, t) @@ -252,6 +252,28 @@ def vae_decode(latents, state, pipeline): ).sample image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) return image + + def get_quantized_unet_vars(unet_state, params, rng, config, batch_size, pipeline): + + (latents, + prompt_embeds, + added_cond_kwargs, + guidance_scale, + guidance_rescale, + scheduler_state) = get_unet_inputs(rng, config, batch_size, pipeline, params) + + loop_body_quant_p = jax.jit(functools.partial(loop_body_for_quantization, + model=pipeline.unet, + pipeline=pipeline, + added_cond_kwargs=added_cond_kwargs, + prompt_embeds=prompt_embeds, + guidance_scale=guidance_scale, + guidance_rescale=guidance_rescale)) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + quantized_unet_vars = loop_body_quant_p(latents=latents, scheduler_state=scheduler_state, state=unet_state,rng=rng) + + + return quantized_unet_vars def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeline): @@ -262,14 +284,16 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli guidance_rescale, scheduler_state) = get_unet_inputs(rng, config, batch_size, pipeline, params) - loop_body_quant_p = functools.partial(loop_body_for_quantization, model=pipeline.unet, - pipeline=pipeline, - added_cond_kwargs=added_cond_kwargs, - prompt_embeds=prompt_embeds, - guidance_scale=guidance_scale, - guidance_rescale=guidance_rescale) - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - quantized_unet_vars = loop_body_quant_p, (latents, scheduler_state, unet_state, rng) + # loop_body_quant_p = jax.jit(functools.partial(loop_body_for_quantization, + # model=pipeline.unet, + # pipeline=pipeline, + # added_cond_kwargs=added_cond_kwargs, + # prompt_embeds=prompt_embeds, + # guidance_scale=guidance_scale, + # guidance_rescale=guidance_rescale)) + # with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + # quantized_unet_vars = loop_body_quant_p(latents=latents, scheduler_state=scheduler_state, state=unet_state,rng=rng) + unboxed_abstract_state, state_mesh_annotations = get_abstract_state(pipeline.unet, None, config, mesh, quantized_unet_vars, training=False) unet_state, unet_state_mesh_shardings = setup_initial_state( @@ -298,7 +322,42 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli loop_body_p, (latents, scheduler_state, unet_state)) image = vae_decode_p(latents, vae_state) return image + + quantized_unet_vars = get_quantized_unet_vars(unet_state, params, rng, config, batch_size, pipeline) 
+ + del params + del pipeline + + quant = quantizations.configure_quantization(config=config, lhs_quant_mode=aqt_flax.QuantMode.TRAIN, rhs_quant_mode=aqt_flax.QuantMode.SERVE) + pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=weight_dtype, + split_head_dim=config.split_head_dim, + norm_num_groups=config.norm_num_groups, + attention_kernel=config.attention, + flash_block_sizes=flash_block_sizes, + mesh=mesh, + quant=quant, + ) + + scheduler_state = params.pop("scheduler") + old_params = params + params = jax.tree_util.tree_map(lambda x: x.astype(weight_dtype), old_params) + params["scheduler"] = scheduler_state + + data_sharding = jax.sharding.NamedSharding(mesh,P(*config.data_sharding)) + sharding = PositionalSharding(devices_array).replicate() + partial_device_put_replicated = functools.partial(device_put_replicated, sharding=sharding) + params["text_encoder"] = jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder"]) + params["text_encoder_2"] = jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder_2"]) + + unet_state, unet_state_mesh_shardings, vae_state, vae_state_mesh_shardings = get_states(mesh, None, rng, config, pipeline, quantized_unet_vars, params["vae"], training=False) + del params["vae"] + del params["unet"] + + p_run_inference = jax.jit( functools.partial(run_inference, rng=rng, config=config, batch_size=batch_size, pipeline=pipeline), in_shardings=(unet_state_mesh_shardings, vae_state_mesh_shardings, None), From faf9a4f0ee0c4e9e439cbb24978348cb4c35be8a Mon Sep 17 00:00:00 2001 From: mailvijayasingh Date: Wed, 5 Jun 2024 00:35:00 +0000 Subject: [PATCH 6/7] wip aqt --- src/maxdiffusion/max_utils.py | 2 +- src/maxdiffusion/models/quantizations.py | 4 +++ .../models/unet_2d_blocks_flax.py | 2 +- src/maxdiffusion/unet_quantization.py | 35 ++++--------------- 4 files changed, 12 insertions(+), 31 deletions(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 45844fde..7df15eb9 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -354,7 +354,7 @@ def get_states(mesh, tx, rng, config, pipeline, unet_params, vae_params, trainin unet_variables = jax.jit(pipeline.unet.init_weights, static_argnames=["eval_only"])(rng, eval_only=False) else: #unet_variables = jax.jit(pipeline.unet.init_weights, static_argnames=["quantization_enabled"])(rng, quantization_enabled=quant_enabled) - unet_variables = pipeline.unet.init_weights(rng, eval_only=True, quant_enabled=quant_enabled) + unet_variables = pipeline.unet.init_weights(rng, eval_only=True, quantization_enabled=quant_enabled) unboxed_abstract_state, state_mesh_annotations = get_abstract_state(pipeline.unet, tx, config, mesh, unet_variables, training=training) if config.train_new_unet: diff --git a/src/maxdiffusion/models/quantizations.py b/src/maxdiffusion/models/quantizations.py index acb9c90c..b2fd6c75 100644 --- a/src/maxdiffusion/models/quantizations.py +++ b/src/maxdiffusion/models/quantizations.py @@ -35,6 +35,8 @@ def dot_general_cls(self): self.quant_dg, lhs_quant_mode=self.lhs_quant_mode, rhs_quant_mode=self.rhs_quant_mode, + lhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION, + rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE, ) return aqt_dg_cls @@ -44,6 +46,8 @@ def einsum(self): cfg=self.quant_dg, lhs_quant_mode=self.lhs_quant_mode, rhs_quant_mode=self.rhs_quant_mode, + lhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION, + 
rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE, ) ) return aqt_einsum diff --git a/src/maxdiffusion/models/unet_2d_blocks_flax.py b/src/maxdiffusion/models/unet_2d_blocks_flax.py index 3aa1e5cc..8acd86b7 100644 --- a/src/maxdiffusion/models/unet_2d_blocks_flax.py +++ b/src/maxdiffusion/models/unet_2d_blocks_flax.py @@ -419,7 +419,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): dtype: jnp.dtype = jnp.float32 transformer_layers_per_block: int = 1 norm_num_groups: int = 32 - quant=None + quant: Quant = None def setup(self): # there is always at least one resnet diff --git a/src/maxdiffusion/unet_quantization.py b/src/maxdiffusion/unet_quantization.py index c0114f4a..79ebabe9 100644 --- a/src/maxdiffusion/unet_quantization.py +++ b/src/maxdiffusion/unet_quantization.py @@ -41,6 +41,7 @@ from maxdiffusion import pyconfig from maxdiffusion.image_processor import VaeImageProcessor from maxdiffusion.max_utils import ( + InferenceState, create_device_mesh, get_dtype, get_states, @@ -284,29 +285,6 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli guidance_rescale, scheduler_state) = get_unet_inputs(rng, config, batch_size, pipeline, params) - # loop_body_quant_p = jax.jit(functools.partial(loop_body_for_quantization, - # model=pipeline.unet, - # pipeline=pipeline, - # added_cond_kwargs=added_cond_kwargs, - # prompt_embeds=prompt_embeds, - # guidance_scale=guidance_scale, - # guidance_rescale=guidance_rescale)) - # with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - # quantized_unet_vars = loop_body_quant_p(latents=latents, scheduler_state=scheduler_state, state=unet_state,rng=rng) - - - unboxed_abstract_state, state_mesh_annotations = get_abstract_state(pipeline.unet, None, config, mesh, quantized_unet_vars, training=False) - unet_state, unet_state_mesh_shardings = setup_initial_state( - pipeline.unet, - None, - config, - mesh, - quantized_unet_vars, - unboxed_abstract_state, - state_mesh_annotations, - training=False) - - loop_body_p = functools.partial(loop_body, model=pipeline.unet, pipeline=pipeline, added_cond_kwargs=added_cond_kwargs, @@ -315,8 +293,6 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli guidance_rescale=guidance_rescale) vae_decode_p = functools.partial(vae_decode, pipeline=pipeline) - - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): latents, _, _ = jax.lax.fori_loop(0, config.num_inference_steps, loop_body_p, (latents, scheduler_state, unet_state)) @@ -328,7 +304,7 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli del params del pipeline - quant = quantizations.configure_quantization(config=config, lhs_quant_mode=aqt_flax.QuantMode.TRAIN, rhs_quant_mode=aqt_flax.QuantMode.SERVE) + quant = quantizations.configure_quantization(config=config, lhs_quant_mode=aqt_flax.QuantMode.TRAIN, rhs_quant_mode=aqt_flax.QuantMode.CONVERT) pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( config.pretrained_model_name_or_path, revision=config.revision, @@ -353,9 +329,10 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli params["text_encoder"] = jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder"]) params["text_encoder_2"] = jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder_2"]) - unet_state, unet_state_mesh_shardings, vae_state, vae_state_mesh_shardings = get_states(mesh, None, rng, config, pipeline, quantized_unet_vars, params["vae"], training=False) 
- del params["vae"] - del params["unet"] + unet_state = InferenceState(pipeline.unet.apply, params=quantized_unet_vars) + # unet_state, unet_state_mesh_shardings, vae_state, vae_state_mesh_shardings = get_states(mesh, None, rng, config, pipeline, quantized_unet_vars, params["vae"], training=False) + # del params["vae"] + # del params["unet"] p_run_inference = jax.jit( From d86a7ab8235e4344c67e357e70fe9bddc10d47e0 Mon Sep 17 00:00:00 2001 From: mailvijayasingh Date: Fri, 7 Jun 2024 00:50:19 +0000 Subject: [PATCH 7/7] Working E2E Quantization --- src/maxdiffusion/max_utils.py | 14 +- src/maxdiffusion/unet_quantization.py | 183 ++++++++++++++++++++------ 2 files changed, 151 insertions(+), 46 deletions(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 7df15eb9..841a3b9b 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -330,6 +330,7 @@ def setup_initial_state(model, tx, config, mesh, model_params, unboxed_abstract_ init_train_state_partial = functools.partial(init_train_state, model=model, tx=tx, training=training) sharding = PositionalSharding(mesh.devices).replicate() + # TODO - Inspect structure of sharding? partial_device_put_replicated = functools.partial(device_put_replicated, sharding=sharding) model_params = jax.tree_util.tree_map(partial_device_put_replicated, model_params) @@ -344,9 +345,10 @@ def setup_initial_state(model, tx, config, mesh, model_params, unboxed_abstract_ state_mesh_shardings = jax.tree_util.tree_map( lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations) + return state, state_mesh_shardings -def get_states(mesh, tx, rng, config, pipeline, unet_params, vae_params, training=True): +def get_states(mesh, tx, rng, config, pipeline, unet_params, vae_params, training=True, q_v=None): # Needed to initialize weights on multi-host with addressable devices. 
quant_enabled = config.quantization is not None @@ -355,8 +357,10 @@ def get_states(mesh, tx, rng, config, pipeline, unet_params, vae_params, trainin else: #unet_variables = jax.jit(pipeline.unet.init_weights, static_argnames=["quantization_enabled"])(rng, quantization_enabled=quant_enabled) unet_variables = pipeline.unet.init_weights(rng, eval_only=True, quantization_enabled=quant_enabled) - - unboxed_abstract_state, state_mesh_annotations = get_abstract_state(pipeline.unet, tx, config, mesh, unet_variables, training=training) + if q_v: + unboxed_abstract_state, state_mesh_annotations = get_abstract_state(pipeline.unet, tx, config, mesh, q_v, training=training) + else: + unboxed_abstract_state, state_mesh_annotations = get_abstract_state(pipeline.unet, tx, config, mesh, unet_variables, training=training) if config.train_new_unet: unet_params = unet_variables else: @@ -366,7 +370,7 @@ def get_states(mesh, tx, rng, config, pipeline, unet_params, vae_params, trainin tx, config, mesh, - unet_params, + q_v, unboxed_abstract_state, state_mesh_annotations, training=training) @@ -383,7 +387,7 @@ def get_states(mesh, tx, rng, config, pipeline, unet_params, vae_params, trainin state_mesh_annotations, training=training ) - + # breakpoint() return unet_state, unet_state_mesh_shardings, vae_state, vae_state_mesh_shardings # Learning Rate Schedule diff --git a/src/maxdiffusion/unet_quantization.py b/src/maxdiffusion/unet_quantization.py index 79ebabe9..ff176103 100644 --- a/src/maxdiffusion/unet_quantization.py +++ b/src/maxdiffusion/unet_quantization.py @@ -19,6 +19,7 @@ from absl import app from typing import Sequence import time +from maxdiffusion.models.unet_2d_condition_flax import FlaxUNet2DConditionModel import numpy as np import jax @@ -29,7 +30,7 @@ from flax.linen import partitioning as nn_partitioning from jax.sharding import PositionalSharding from aqt.jax.v2.flax import aqt_flax - +import optax from maxdiffusion import ( FlaxStableDiffusionXLPipeline, @@ -50,7 +51,8 @@ device_put_replicated, get_flash_block_sizes, get_abstract_state, - setup_initial_state + setup_initial_state, + create_learning_rate_schedule ) from .maxdiffusion_utils import ( load_sdxllightning_unet, @@ -58,9 +60,85 @@ rescale_noise_cfg ) from .models import quantizations +from jax.tree_util import tree_flatten_with_path, tree_unflatten + cc.set_cache_dir(os.path.expanduser("~/jax_cache")) +def _get_aqt_key_paths(aqt_vars): + """Generate a list of paths which have aqt state""" + aqt_tree_flat, _ = jax.tree_util.tree_flatten_with_path(aqt_vars) + aqt_key_paths = [] + for k, _ in aqt_tree_flat: + pruned_keys = [] + for d in list(k): + if "AqtDotGeneral" in d.key: + pruned_keys.append(jax.tree_util.DictKey(key="kernel")) + break + else: + assert "Aqt" not in d.key, f"Unexpected Aqt op {d.key} in {k}." 
+        pruned_keys.append(d)
+    aqt_key_paths.append(tuple(pruned_keys))
+  return aqt_key_paths
+
+def remove_quantized_params(params, aqt_vars):
+  """Replace params that have corresponding AQT tensors with empty dicts to save memory."""
+  aqt_paths = _get_aqt_key_paths(aqt_vars)
+  tree_flat, tree_struct = tree_flatten_with_path(params)
+  for i, (k, v) in enumerate(tree_flat):
+    if k in aqt_paths:
+      v = {}
+    tree_flat[i] = v
+  return tree_unflatten(tree_struct, tree_flat)
+
+
+def get_quantized_unet_variables(config):
+
+  # Setup Mesh
+  devices_array = create_device_mesh(config)
+  mesh = Mesh(devices_array, config.mesh_axes)
+
+  batch_size = config.per_device_batch_size * jax.device_count()
+
+  weight_dtype = get_dtype(config)
+  flash_block_sizes = get_flash_block_sizes(config)
+
+  quant = quantizations.configure_quantization(config=config, lhs_quant_mode=aqt_flax.QuantMode.TRAIN, rhs_quant_mode=aqt_flax.QuantMode.CONVERT)
+  pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
+    config.pretrained_model_name_or_path,
+    revision=config.revision,
+    dtype=weight_dtype,
+    split_head_dim=config.split_head_dim,
+    norm_num_groups=config.norm_num_groups,
+    attention_kernel=config.attention,
+    flash_block_sizes=flash_block_sizes,
+    mesh=mesh,
+    quant=quant,
+  )
+
+  # Dummy inputs: only the shapes matter; one traced forward pass materializes the AQT tensors.
+  latents = jnp.ones((8, 4, 128, 128), dtype=jnp.float32)
+  timesteps = jnp.ones((8,))
+  encoder_hidden_states = jnp.ones((8, 77, 2048))
+
+  added_cond_kwargs = {
+    "text_embeds": jnp.zeros((8, 1280), dtype=jnp.float32),
+    "time_ids": jnp.zeros((8, 6), dtype=jnp.float32),
+  }
+  _, quantized_unet_vars = pipeline.unet.apply(
+    params["unet"] | {"aqt": {}},
+    latents,
+    timesteps,
+    encoder_hidden_states=encoder_hidden_states,
+    added_cond_kwargs=added_cond_kwargs,
+    rngs={"params": jax.random.PRNGKey(0)},
+    mutable=True,
+  )
+  del pipeline
+  del params
+
+  return quantized_unet_vars
+
 def loop_body(step, args, model, pipeline, added_cond_kwargs, prompt_embeds, guidance_scale, guidance_rescale):
   latents, scheduler_state, state = args
   latents_input = jnp.concatenate([latents] * 2)
@@ -69,12 +147,14 @@ def loop_body(step, args, model, pipeline, added_cond_kwargs, prompt_embeds, gui
   timestep = jnp.broadcast_to(t, latents_input.shape[0])
 
   latents_input = pipeline.scheduler.scale_model_input(scheduler_state, latents_input, t)
+  # Pass the captured AQT collection alongside the float params so the quantized kernels are used.
   noise_pred = model.apply(
-    {"params" : state.params},
+    {"params": state.params, "aqt": state.params["aqt"]},
     jnp.array(latents_input),
     jnp.array(timestep, dtype=jnp.int32),
     encoder_hidden_states=prompt_embeds,
-    added_cond_kwargs=added_cond_kwargs
+    added_cond_kwargs=added_cond_kwargs,
+    rngs={"params": jax.random.PRNGKey(0)}
   ).sample
 
   noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
@@ -138,7 +218,7 @@ def tokenize(prompt, pipeline):
   inputs = jnp.stack(inputs,axis=1)
   return inputs
 
-def run(config):
+def run(config, q_v):
   rng = jax.random.PRNGKey(config.seed)
 
   # Setup Mesh
@@ -150,7 +230,7 @@ def run(config):
   weight_dtype = get_dtype(config)
   flash_block_sizes = get_flash_block_sizes(config)
 
-  quant = quantizations.configure_quantization(config=config, lhs_quant_mode=aqt_flax.QuantMode.TRAIN, rhs_quant_mode=aqt_flax.QuantMode.CONVERT)
+  quant = quantizations.configure_quantization(config=config, lhs_quant_mode=aqt_flax.QuantMode.TRAIN, rhs_quant_mode=aqt_flax.QuantMode.SERVE)
   pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
     config.pretrained_model_name_or_path,
     revision=config.revision,
@@ -189,9 +269,24 @@ def run(config):
   params["text_encoder"]
= jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder"]) params["text_encoder_2"] = jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder_2"]) - unet_state, unet_state_mesh_shardings, vae_state, vae_state_mesh_shardings = get_states(mesh, None, rng, config, pipeline, params["unet"], params["vae"], training=False) + # p = {} + # p["aqt"] = q_v["aqt"] + # # Remove param values which have corresponding qtensors in aqt to save memory. + # p["params"] = remove_quantized_params(q_v["params"], q_v["aqt"]) + learning_rate_scheduler = create_learning_rate_schedule(config) + tx = optax.adamw( + learning_rate=learning_rate_scheduler, + b1=config.adam_b1, + b2=config.adam_b2, + eps=config.adam_eps, + weight_decay=config.adam_weight_decay, + ) + unet_state, unet_state_mesh_shardings, vae_state, vae_state_mesh_shardings = get_states(mesh, tx, rng, config, pipeline, q_v, params["vae"], training=False, q_v=q_v) del params["vae"] del params["unet"] + # unet_state.params = q_v + # params["unet"] = jax.tree_util.tree_map(partial_device_put_replicated, params["unet"]) + # unet_state = InferenceState(pipeline.unet.apply, params=params["unet"]) def get_unet_inputs(rng, config, batch_size, pipeline, params): vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) @@ -270,8 +365,8 @@ def get_quantized_unet_vars(unet_state, params, rng, config, batch_size, pipelin prompt_embeds=prompt_embeds, guidance_scale=guidance_scale, guidance_rescale=guidance_rescale)) - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - quantized_unet_vars = loop_body_quant_p(latents=latents, scheduler_state=scheduler_state, state=unet_state,rng=rng) + # with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + quantized_unet_vars = loop_body_quant_p(latents=latents, scheduler_state=scheduler_state, state=unet_state,rng=rng) return quantized_unet_vars @@ -299,40 +394,41 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli image = vae_decode_p(latents, vae_state) return image - quantized_unet_vars = get_quantized_unet_vars(unet_state, params, rng, config, batch_size, pipeline) + #quantized_unet_vars = get_quantized_unet_vars(unet_state, params, rng, config, batch_size, pipeline) - del params - del pipeline + #del params + #del pipeline + #del unet_state + #quant = quantizations.configure_quantization(config=config, lhs_quant_mode=aqt_flax.QuantMode.TRAIN, rhs_quant_mode=aqt_flax.QuantMode.SERVE) - quant = quantizations.configure_quantization(config=config, lhs_quant_mode=aqt_flax.QuantMode.TRAIN, rhs_quant_mode=aqt_flax.QuantMode.CONVERT) - pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - dtype=weight_dtype, - split_head_dim=config.split_head_dim, - norm_num_groups=config.norm_num_groups, - attention_kernel=config.attention, - flash_block_sizes=flash_block_sizes, - mesh=mesh, - quant=quant, - ) - - scheduler_state = params.pop("scheduler") - old_params = params - params = jax.tree_util.tree_map(lambda x: x.astype(weight_dtype), old_params) - params["scheduler"] = scheduler_state - - data_sharding = jax.sharding.NamedSharding(mesh,P(*config.data_sharding)) - - sharding = PositionalSharding(devices_array).replicate() - partial_device_put_replicated = functools.partial(device_put_replicated, sharding=sharding) - params["text_encoder"] = jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder"]) - params["text_encoder_2"] = 
jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder_2"])
-
-  unet_state = InferenceState(pipeline.unet.apply, params=quantized_unet_vars)
+  # pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
+  #   config.pretrained_model_name_or_path,
+  #   revision=config.revision,
+  #   dtype=weight_dtype,
+  #   split_head_dim=config.split_head_dim,
+  #   norm_num_groups=config.norm_num_groups,
+  #   attention_kernel=config.attention,
+  #   flash_block_sizes=flash_block_sizes,
+  #   mesh=mesh,
+  #   quant=quant,
+  # )
+
+  # scheduler_state = params.pop("scheduler")
+  # old_params = params
+  # params = jax.tree_util.tree_map(lambda x: x.astype(weight_dtype), old_params)
+  # params["scheduler"] = scheduler_state
+
+  # data_sharding = jax.sharding.NamedSharding(mesh,P(*config.data_sharding))
+
+  # sharding = PositionalSharding(devices_array).replicate()
+  # partial_device_put_replicated = functools.partial(device_put_replicated, sharding=sharding)
+  # params["text_encoder"] = jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder"])
+  # params["text_encoder_2"] = jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder_2"])
+
+  # unet_state = InferenceState(pipeline.unet.apply, params=quantized_unet_vars)
   # unet_state, unet_state_mesh_shardings, vae_state, vae_state_mesh_shardings = get_states(mesh, None, rng, config, pipeline, quantized_unet_vars, params["vae"], training=False)
-  # del params["vae"]
-  # del params["unet"]
+  #del params["vae"]
+  #del params["unet"]
 
 
   p_run_inference = jax.jit(
@@ -372,7 +468,12 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli
 
 def main(argv: Sequence[str]) -> None:
   pyconfig.initialize(argv)
-  run(pyconfig.config)
+  # Conversion pass: trace the UNet once in CONVERT mode to materialize the quantized (AQT) tensors.
+  q_v = get_quantized_unet_variables(pyconfig.config)
+  # Drop the float params; the remaining collections (e.g. "aqt") hold the quantized tensors.
+  del q_v['params']
+  # Serving pass: rebuild the pipeline in SERVE mode and run inference with the captured tensors.
+  run(pyconfig.config, q_v)
 
 if __name__ == "__main__":
   app.run(main)
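
Note on the flow this patch lands: quantization is done in two passes. A CONVERT pass runs the UNet once on dummy inputs so the mutable "aqt" collection is populated with quantized tensors, and a SERVE pass rebuilds the pipeline with SERVE-mode quantization and runs inference against those captured tensors. The sketch below condenses those calls into a single helper; the helper name convert_unet_to_aqt and the config/mesh/dtype plumbing passed into it are illustrative assumptions, while the quantization calls, keyword arguments, and dummy shapes mirror unet_quantization.py above.

# Condensed sketch of the CONVERT -> SERVE flow introduced in unet_quantization.py (not part of the patch).
import jax
import jax.numpy as jnp
from aqt.jax.v2.flax import aqt_flax
from maxdiffusion import FlaxStableDiffusionXLPipeline
from maxdiffusion.models import quantizations

def convert_unet_to_aqt(config, mesh, weight_dtype, flash_block_sizes):
  """CONVERT pass: trace the UNet once so the mutable "aqt" collection is populated."""
  quant = quantizations.configure_quantization(
      config=config,
      lhs_quant_mode=aqt_flax.QuantMode.TRAIN,
      rhs_quant_mode=aqt_flax.QuantMode.CONVERT,
  )
  pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
      config.pretrained_model_name_or_path,
      revision=config.revision,
      dtype=weight_dtype,
      split_head_dim=config.split_head_dim,
      norm_num_groups=config.norm_num_groups,
      attention_kernel=config.attention,
      flash_block_sizes=flash_block_sizes,
      mesh=mesh,
      quant=quant,
  )
  # Dummy SDXL UNet inputs; only the shapes matter for this tracing pass.
  latents = jnp.ones((8, 4, 128, 128), jnp.float32)
  timesteps = jnp.ones((8,))
  encoder_hidden_states = jnp.ones((8, 77, 2048))
  added_cond_kwargs = {
      "text_embeds": jnp.zeros((8, 1280), jnp.float32),
      "time_ids": jnp.zeros((8, 6), jnp.float32),
  }
  # mutable=True returns the mutated variable collections, including the new "aqt" tensors.
  _, quantized_vars = pipeline.unet.apply(
      params["unet"] | {"aqt": {}},
      latents,
      timesteps,
      encoder_hidden_states=encoder_hidden_states,
      added_cond_kwargs=added_cond_kwargs,
      rngs={"params": jax.random.PRNGKey(0)},
      mutable=True,
  )
  # Drop the float params before the SERVE pass; the "aqt" collection carries the quantized tensors.
  del quantized_vars["params"]
  return quantized_vars

# SERVE pass, as in run(): rebuild the pipeline with rhs_quant_mode=aqt_flax.QuantMode.SERVE and
# feed the captured collection to the UNet apply alongside the remaining params.

In this patch the same steps are split between get_quantized_unet_variables() (CONVERT) and run() (SERVE), with get_states() and setup_initial_state() building the sharded inference state from the captured variables.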