diff --git a/invokeai/app/invocations/cogview4_image_to_latents.py b/invokeai/app/invocations/cogview4_image_to_latents.py
index db44c6d220a..630b9ab1e3d 100644
--- a/invokeai/app/invocations/cogview4_image_to_latents.py
+++ b/invokeai/app/invocations/cogview4_image_to_latents.py
@@ -17,6 +17,7 @@
 from invokeai.backend.model_manager.load.load_base import LoadedModel
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_cogview4

 # TODO(ryand): This is effectively a copy of SD3ImageToLatentsInvocation and a subset of ImageToLatentsInvocation. We
 # should refactor to avoid this duplication.
@@ -36,18 +37,12 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
     image: ImageField = InputField(description="The image to encode.")
     vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

-    def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoencoderKL) -> int:
-        """Estimate the working memory required by the invocation in bytes."""
-        # Encode operations use approximately 50% of the memory required for decode operations
-        h = image_tensor.shape[-2]
-        w = image_tensor.shape[-1]
-        element_size = next(vae.parameters()).element_size()
-        scaling_constant = 1100  # 50% of decode scaling constant (2200)
-        working_memory = h * w * element_size * scaling_constant
-        return int(working_memory)
-
     @staticmethod
-    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
+    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
+        assert isinstance(vae_info.model, AutoencoderKL)
+        estimated_working_memory = estimate_vae_working_memory_cogview4(
+            operation="encode", image_tensor=image_tensor, vae=vae_info.model
+        )
         with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoencoderKL)

@@ -74,10 +69,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
         vae_info = context.models.load(self.vae.vae)
         assert isinstance(vae_info.model, AutoencoderKL)

-        estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
-        latents = self.vae_encode(
-            vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory
-        )
+        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

         latents = latents.to("cpu")
         name = context.tensors.save(tensor=latents)
diff --git a/invokeai/app/invocations/cogview4_latents_to_image.py b/invokeai/app/invocations/cogview4_latents_to_image.py
index 880b2f4dc22..1b77ed8a1f8 100644
--- a/invokeai/app/invocations/cogview4_latents_to_image.py
+++ b/invokeai/app/invocations/cogview4_latents_to_image.py
@@ -6,7 +6,6 @@
 from PIL import Image

 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
-from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
 from invokeai.app.invocations.fields import (
     FieldDescriptions,
     Input,
@@ -20,6 +19,7 @@
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_cogview4

 # TODO(ryand): This is effectively a copy of SD3LatentsToImageInvocation and a subset of LatentsToImageInvocation. We
 # should refactor to avoid this duplication.
@@ -39,22 +39,15 @@ class CogView4LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
     vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

-    def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
-        """Estimate the working memory required by the invocation in bytes."""
-        out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
-        out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
-        element_size = next(vae.parameters()).element_size()
-        scaling_constant = 2200  # Determined experimentally.
-        working_memory = out_h * out_w * element_size * scaling_constant
-        return int(working_memory)
-
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.tensors.load(self.latents.latents_name)

         vae_info = context.models.load(self.vae.vae)
         assert isinstance(vae_info.model, (AutoencoderKL))
-        estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
+        estimated_working_memory = estimate_vae_working_memory_cogview4(
+            operation="decode", image_tensor=latents, vae=vae_info.model
+        )
         with (
             SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
             vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
diff --git a/invokeai/app/invocations/flux_vae_decode.py b/invokeai/app/invocations/flux_vae_decode.py
index e0ea8d15077..c55dfb539ac 100644
--- a/invokeai/app/invocations/flux_vae_decode.py
+++ b/invokeai/app/invocations/flux_vae_decode.py
@@ -3,7 +3,6 @@
 from PIL import Image

 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
 from invokeai.app.invocations.fields import (
     FieldDescriptions,
     Input,
@@ -18,6 +17,7 @@
 from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.model_manager.load.load_base import LoadedModel
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux


 @invocation(
@@ -39,17 +39,11 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
         input=Input.Connection,
     )

-    def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoEncoder) -> int:
-        """Estimate the working memory required by the invocation in bytes."""
-        out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
-        out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
-        element_size = next(vae.parameters()).element_size()
-        scaling_constant = 2200  # Determined experimentally.
-        working_memory = out_h * out_w * element_size * scaling_constant
-        return int(working_memory)
-
     def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
-        estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
+        assert isinstance(vae_info.model, AutoEncoder)
+        estimated_working_memory = estimate_vae_working_memory_flux(
+            operation="decode", image_tensor=latents, vae=vae_info.model
+        )
         with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoEncoder)
             vae_dtype = next(iter(vae.parameters())).dtype
diff --git a/invokeai/app/invocations/flux_vae_encode.py b/invokeai/app/invocations/flux_vae_encode.py
index a99e39bc05f..2932517edcf 100644
--- a/invokeai/app/invocations/flux_vae_encode.py
+++ b/invokeai/app/invocations/flux_vae_encode.py
@@ -15,6 +15,7 @@
 from invokeai.backend.model_manager import LoadedModel
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux


 @invocation(
@@ -35,22 +36,16 @@ class FluxVaeEncodeInvocation(BaseInvocation):
         input=Input.Connection,
     )

-    def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoEncoder) -> int:
-        """Estimate the working memory required by the invocation in bytes."""
-        # Encode operations use approximately 50% of the memory required for decode operations
-        h = image_tensor.shape[-2]
-        w = image_tensor.shape[-1]
-        element_size = next(vae.parameters()).element_size()
-        scaling_constant = 1100  # 50% of decode scaling constant (2200)
-        working_memory = h * w * element_size * scaling_constant
-        return int(working_memory)
-
     @staticmethod
-    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
+    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
         # TODO(ryand): Expose seed parameter at the invocation level.
         # TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
         # There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
         # should be used for VAE encode sampling.
+        assert isinstance(vae_info.model, AutoEncoder)
+        estimated_working_memory = estimate_vae_working_memory_flux(
+            operation="encode", image_tensor=image_tensor, vae=vae_info.model
+        )
         generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
         with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoEncoder)
@@ -70,10 +65,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

         context.util.signal_progress("Running VAE")
-        estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
-        latents = self.vae_encode(
-            vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory
-        )
+        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

         latents = latents.to("cpu")
         name = context.tensors.save(tensor=latents)
diff --git a/invokeai/app/invocations/image_to_latents.py b/invokeai/app/invocations/image_to_latents.py
index 98116e2d8d4..552f5edb1b2 100644
--- a/invokeai/app/invocations/image_to_latents.py
+++ b/invokeai/app/invocations/image_to_latents.py
@@ -27,6 +27,7 @@
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
 from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl


 @invocation(
@@ -52,47 +53,23 @@ class ImageToLatentsInvocation(BaseInvocation):
     tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
     fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)

-    def _estimate_working_memory(
-        self, image_tensor: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
-    ) -> int:
-        """Estimate the working memory required by the invocation in bytes."""
-        # Encode operations use approximately 50% of the memory required for decode operations
-        element_size = 4 if self.fp32 else 2
-        scaling_constant = 1100  # 50% of decode scaling constant (2200)
-
-        if use_tiling:
-            tile_size = self.tile_size
-            if tile_size == 0:
-                tile_size = vae.tile_sample_min_size
-            assert isinstance(tile_size, int)
-            h = tile_size
-            w = tile_size
-            working_memory = h * w * element_size * scaling_constant
-
-            # We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
-            # and number of tiles. We could make this more precise in the future, but this should be good enough for
-            # most use cases.
-            working_memory = working_memory * 1.25
-        else:
-            h = image_tensor.shape[-2]
-            w = image_tensor.shape[-1]
-            working_memory = h * w * element_size * scaling_constant
-
-        if self.fp32:
-            # If we are running in FP32, then we should account for the likely increase in model size (~250MB).
-            working_memory += 250 * 2**20
-
-        return int(working_memory)
-
-    @staticmethod
+    @classmethod
     def vae_encode(
+        cls,
         vae_info: LoadedModel,
         upcast: bool,
         tiled: bool,
         image_tensor: torch.Tensor,
         tile_size: int = 0,
-        estimated_working_memory: int = 0,
     ) -> torch.Tensor:
+        assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
+        estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
+            operation="encode",
+            image_tensor=image_tensor,
+            vae=vae_info.model,
+            tile_size=tile_size if tiled else None,
+            fp32=upcast,
+        )
         with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
             orig_dtype = vae.dtype
@@ -156,17 +133,13 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
         if image_tensor.dim() == 3:
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

-        use_tiling = self.tiled or context.config.get().force_tiled_decode
-        estimated_working_memory = self._estimate_working_memory(image_tensor, use_tiling, vae_info.model)
-
         context.util.signal_progress("Running VAE encoder")
         latents = self.vae_encode(
             vae_info=vae_info,
             upcast=self.fp32,
-            tiled=self.tiled,
+            tiled=self.tiled or context.config.get().force_tiled_decode,
             image_tensor=image_tensor,
             tile_size=self.tile_size,
-            estimated_working_memory=estimated_working_memory,
         )

         latents = latents.to("cpu")
diff --git a/invokeai/app/invocations/latents_to_image.py b/invokeai/app/invocations/latents_to_image.py
index 00261bfe25a..ab1096caf7c 100644
--- a/invokeai/app/invocations/latents_to_image.py
+++ b/invokeai/app/invocations/latents_to_image.py
@@ -27,6 +27,7 @@
 from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
 from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl


 @invocation(
@@ -53,39 +54,6 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
     fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)

-    def _estimate_working_memory(
-        self, latents: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
-    ) -> int:
-        """Estimate the working memory required by the invocation in bytes."""
-        # It was found experimentally that the peak working memory scales linearly with the number of pixels and the
-        # element size (precision). This estimate is accurate for both SD1 and SDXL.
-        element_size = 4 if self.fp32 else 2
-        scaling_constant = 2200  # Determined experimentally.
-
-        if use_tiling:
-            tile_size = self.tile_size
-            if tile_size == 0:
-                tile_size = vae.tile_sample_min_size
-            assert isinstance(tile_size, int)
-            out_h = tile_size
-            out_w = tile_size
-            working_memory = out_h * out_w * element_size * scaling_constant
-
-            # We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
-            # and number of tiles. We could make this more precise in the future, but this should be good enough for
-            # most use cases.
-            working_memory = working_memory * 1.25
-        else:
-            out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
-            out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
-            working_memory = out_h * out_w * element_size * scaling_constant
-
-        if self.fp32:
-            # If we are running in FP32, then we should account for the likely increase in model size (~250MB).
-            working_memory += 250 * 2**20
-
-        return int(working_memory)
-
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.tensors.load(self.latents.latents_name)
@@ -94,8 +62,13 @@ def invoke(self, context: InvocationContext) -> ImageOutput:

         vae_info = context.models.load(self.vae.vae)
         assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
-
-        estimated_working_memory = self._estimate_working_memory(latents, use_tiling, vae_info.model)
+        estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
+            operation="decode",
+            image_tensor=latents,
+            vae=vae_info.model,
+            tile_size=self.tile_size if use_tiling else None,
+            fp32=self.fp32,
+        )
         with (
             SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
             vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
diff --git a/invokeai/app/invocations/sd3_image_to_latents.py b/invokeai/app/invocations/sd3_image_to_latents.py
index abe37d195fc..71a48ee9ad6 100644
--- a/invokeai/app/invocations/sd3_image_to_latents.py
+++ b/invokeai/app/invocations/sd3_image_to_latents.py
@@ -17,6 +17,7 @@
 from invokeai.backend.model_manager.load.load_base import LoadedModel
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd3


 @invocation(
@@ -32,18 +33,12 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
     image: ImageField = InputField(description="The image to encode")
     vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

-    def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoencoderKL) -> int:
-        """Estimate the working memory required by the invocation in bytes."""
-        # Encode operations use approximately 50% of the memory required for decode operations
-        h = image_tensor.shape[-2]
-        w = image_tensor.shape[-1]
-        element_size = next(vae.parameters()).element_size()
-        scaling_constant = 1100  # 50% of decode scaling constant (2200)
-        working_memory = h * w * element_size * scaling_constant
-        return int(working_memory)
-
     @staticmethod
-    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
+    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
+        assert isinstance(vae_info.model, AutoencoderKL)
+        estimated_working_memory = estimate_vae_working_memory_sd3(
+            operation="encode", image_tensor=image_tensor, vae=vae_info.model
+        )
         with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoencoderKL)

@@ -70,10 +65,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
         vae_info = context.models.load(self.vae.vae)
         assert isinstance(vae_info.model, AutoencoderKL)

-        estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
-        latents = self.vae_encode(
-            vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory
-        )
+        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

         latents = latents.to("cpu")
         name = context.tensors.save(tensor=latents)
diff --git a/invokeai/app/invocations/sd3_latents_to_image.py b/invokeai/app/invocations/sd3_latents_to_image.py
index 794464b97f4..e6a20d38a9c 100644
--- a/invokeai/app/invocations/sd3_latents_to_image.py
+++ b/invokeai/app/invocations/sd3_latents_to_image.py
@@ -6,7 +6,6 @@
 from PIL import Image

 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
 from invokeai.app.invocations.fields import (
     FieldDescriptions,
     Input,
@@ -20,6 +19,7 @@
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd3


 @invocation(
@@ -41,22 +41,15 @@ class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         input=Input.Connection,
     )

-    def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
-        """Estimate the working memory required by the invocation in bytes."""
-        out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
-        out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
-        element_size = next(vae.parameters()).element_size()
-        scaling_constant = 2200  # Determined experimentally.
-        working_memory = out_h * out_w * element_size * scaling_constant
-        return int(working_memory)
-
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.tensors.load(self.latents.latents_name)

         vae_info = context.models.load(self.vae.vae)
         assert isinstance(vae_info.model, (AutoencoderKL))
-        estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
+        estimated_working_memory = estimate_vae_working_memory_sd3(
+            operation="decode", image_tensor=latents, vae=vae_info.model
+        )
         with (
             SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
             vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
diff --git a/invokeai/backend/util/vae_working_memory.py b/invokeai/backend/util/vae_working_memory.py
new file mode 100644
index 00000000000..7259237f568
--- /dev/null
+++ b/invokeai/backend/util/vae_working_memory.py
@@ -0,0 +1,117 @@
+from typing import Literal
+
+import torch
+from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
+from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
+
+from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
+from invokeai.backend.flux.modules.autoencoder import AutoEncoder
+
+
+def estimate_vae_working_memory_sd15_sdxl(
+    operation: Literal["encode", "decode"],
+    image_tensor: torch.Tensor,
+    vae: AutoencoderKL | AutoencoderTiny,
+    tile_size: int | None,
+    fp32: bool,
+) -> int:
+    """Estimate the working memory required to encode or decode the given tensor."""
+    # It was found experimentally that the peak working memory scales linearly with the number of pixels and the
+    # element size (precision). This estimate is accurate for both SD1 and SDXL.
+    element_size = 4 if fp32 else 2
+
+    # This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
+    # Encoding uses ~45% as much working memory as decoding.
+    scaling_constant = 2200 if operation == "decode" else 1100
+
+    latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
+
+    if tile_size is not None:
+        if tile_size == 0:
+            tile_size = vae.tile_sample_min_size
+        assert isinstance(tile_size, int)
+        h = tile_size
+        w = tile_size
+        working_memory = h * w * element_size * scaling_constant
+
+        # We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
+        # and number of tiles. We could make this more precise in the future, but this should be good enough for
+        # most use cases.
+        working_memory = working_memory * 1.25
+    else:
+        h = latent_scale_factor_for_operation * image_tensor.shape[-2]
+        w = latent_scale_factor_for_operation * image_tensor.shape[-1]
+        working_memory = h * w * element_size * scaling_constant
+
+    if fp32:
+        # If we are running in FP32, then we should account for the likely increase in model size (~250MB).
+        working_memory += 250 * 2**20
+
+    print(f"estimate_vae_working_memory_sd15_sdxl: {int(working_memory)}")
+
+    return int(working_memory)
+
+
+def estimate_vae_working_memory_cogview4(
+    operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKL
+) -> int:
+    """Estimate the working memory required by the invocation in bytes."""
+    latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
+
+    h = latent_scale_factor_for_operation * image_tensor.shape[-2]
+    w = latent_scale_factor_for_operation * image_tensor.shape[-1]
+    element_size = next(vae.parameters()).element_size()
+
+    # This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
+    # Encoding uses ~45% as much working memory as decoding.
+    scaling_constant = 2200 if operation == "decode" else 1100
+    working_memory = h * w * element_size * scaling_constant
+
+    print(f"estimate_vae_working_memory_cogview4: {int(working_memory)}")
+
+    return int(working_memory)
+
+
+def estimate_vae_working_memory_flux(
+    operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoEncoder
+) -> int:
+    """Estimate the working memory required by the invocation in bytes."""
+
+    latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
+
+    out_h = latent_scale_factor_for_operation * image_tensor.shape[-2]
+    out_w = latent_scale_factor_for_operation * image_tensor.shape[-1]
+    element_size = next(vae.parameters()).element_size()
+
+    # This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
+    # Encoding uses ~45% as much working memory as decoding.
+    scaling_constant = 2200 if operation == "decode" else 1100
+
+    working_memory = out_h * out_w * element_size * scaling_constant
+
+    print(f"estimate_vae_working_memory_flux: {int(working_memory)}")
+
+    return int(working_memory)
+
+
+def estimate_vae_working_memory_sd3(
+    operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKL
+) -> int:
+    """Estimate the working memory required by the invocation in bytes."""
+    # Encode operations use approximately 50% of the memory required for decode operations
+
+    latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
+
+    h = latent_scale_factor_for_operation * image_tensor.shape[-2]
+    w = latent_scale_factor_for_operation * image_tensor.shape[-1]
+    element_size = next(vae.parameters()).element_size()
+
+    # This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
+    # Encoding uses ~45% as much working memory as decoding.
+    scaling_constant = 2200 if operation == "decode" else 1100
+
+    working_memory = h * w * element_size * scaling_constant
+
+    print(f"estimate_vae_working_memory_sd3: {int(working_memory)}")
+
+    return int(working_memory)
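
A minimal usage sketch (not part of the patch itself) of how an invocation calls the consolidated SD1.5/SDXL estimator after this change. The wrapper name plan_decode_working_memory is hypothetical; the keyword arguments mirror the new estimate_vae_working_memory_sd15_sdxl signature added above, as used in LatentsToImageInvocation.invoke().

# Hypothetical illustration only; assumes an SD1.5/SDXL AutoencoderKL is already loaded.
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl


def plan_decode_working_memory(vae: AutoencoderKL, latents: torch.Tensor, fp32: bool = False) -> int:
    # tile_size=None skips the tiling branch; when tiling is in effect, pass the node's
    # tile_size (0 falls back to vae.tile_sample_min_size inside the helper).
    return estimate_vae_working_memory_sd15_sdxl(
        operation="decode",
        image_tensor=latents,
        vae=vae,
        tile_size=None,
        fp32=fp32,
    )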