Brandon/flux model loading #6739
Merged
Changes from all commits
113 commits
5256f26 Bump diffusers version to include FLUX support. (RyanJDick)
562c2cc Update imports for compatibility with bumped diffusers version. (RyanJDick)
b617631 Update HF download logic to work for black-forest-labs/FLUX.1-schnell. (RyanJDick)
6a068cc First draft of FluxTextToImageInvocation. (RyanJDick)
5149a3e Add sentencepiece dependency for the T5 tokenizer. (RyanJDick)
761e49f Use the FluxPipeline.encode_prompt() api rather than trying to run th… (RyanJDick)
5e1b3e9 Got FLUX schnell working with 8-bit quantization. Still lots of rough… (RyanJDick)
e71c7d0 Minor improvements to FLUX workflow. (RyanJDick)
f7753be Make 8-bit quantization save/reload work for the FLUX transformer. Re… (RyanJDick)
28654ec Add support for 8-bit quantization of the FLUX T5XXL text encoder. (RyanJDick)
40e9a4e Make float16 inference work with FLUX on 24GB GPU. (RyanJDick)
df1ac07 WIP - experimentation (RyanJDick)
11279ab Make quantized loading fast. (RyanJDick)
bcb7f8e Make quantized loading fast for both T5XXL and FLUX transformer. (RyanJDick)
3f1fbc6 Split a FluxTextEncoderInvocation out from the FluxTextToImageInvocat… (RyanJDick)
06d35c3 wip (RyanJDick)
b1cf2f5 NF4 loading working... I think. (RyanJDick)
31c8d76 NF4 inference working (RyanJDick)
17f5952 Clean up NF4 implementation. (RyanJDick)
52ff3c7 LLM.int8() quantization is working, but still some rough edges to solve. (RyanJDick)
307a130 More improvements for LLM.int8() - not fully tested. (RyanJDick)
27f33bb WIP on moving from diffusers to FLUX (RyanJDick)
e157ff3 Setup flux model loading in the UI (brandonrising)
6779e03 Remove changes to v1 workflow (brandonrising)
4bb3438 Manage quantization of models within the loader (brandonrising)
6b0f5f4 Run Ruff (brandonrising)
3814cd7 Run Ruff (brandonrising)
a6ad70e Some UI cleanup, regenerate schema (brandonrising)
f3096a8 Add backend functions and classes for Flux implementation, Update the… (brandonrising)
5fc6c28 Run ruff, setup initial text to image node (brandonrising)
68d28db Add nf4 bnb quantized format (brandonrising)
f3ebbe1 Remove unused param on _run_vae_decoding in flux text to image (brandonrising)
efab4a3 Working inference node with quantized bnb nf4 checkpoint (brandonrising)
95a2d97 Install sub directories with folders correctly, ensure consistent dty… (brandonrising)
98151ce Select dev/schnell based on state dict, use correct max seq len based… (brandonrising)
5d7e154 Fix FLUX output image clamping. And a few other minor fixes to make i… (RyanJDick)
870ecd3 Add tqdm progress bar to FLUX denoising. (RyanJDick)
24829b9 Fix support for 8b quantized t5 encoders, update exception messages i… (brandonrising)
f36c6d0 Fix styling/lint (brandonrising)
bebc6d3 Add t5 encoders and clip embeds to the model manager (brandonrising)
3f845d9 Some cleanup of the tags and description of flux nodes (brandonrising)
8b3e386 exclude flux models from main model dropdown
35c263a add default workflow for flux t2i
ec360ee Rename t5Encoder -> t5_encoder. (RyanJDick)
c822c3d Address minor review comments. (RyanJDick)
19238ed Update doc string for import_local_model and remove access_token sinc… (brandonrising)
ede26a7 Switch inheritance class of flux model loaders (brandonrising)
f0408bb Various styling and exception type updates (brandonrising)
9e888b1 More flux loader cleanup (brandonrising)
6afb113 Remove duplicate log_time(...) function. (RyanJDick)
519bf71 Add docs to the requantize(...) function explaining why it was copied… (RyanJDick)
c549a49 Move requantize.py to the quantization/ dir. (RyanJDick)
41fb09b update flux_model_loader node to take a T5 encoder from node field in… (maryhipp)
c66ccad add case for clip embed models in probe (maryhipp)
9020a8a add FLUX schnell starter models and submodels as dependencies or adho… (maryhipp)
7264920 fix(ui): only exclude flux main models from linear UI dropdown, not m…
3fe9582 fix(ui): pass base/type when installing models, add flux formats to M…
24831d4 feat(ui): create new field for t5 encoder models in nodes
6899762 tsc and lint fix
f67c4da fix schema
a04d479 update default workflow (maryhipp)
192eda7 fix(worker) fix T5 type (maryhipp)
3c861fd add better workflow description (maryhipp)
dcfdc00 add better workflow name (maryhipp)
f51dd36 Fix bug in InvokeInt8Params that was causing it to use double the nec… (RyanJDick)
3c9811f Update load_flux_model_bnb_llm_int8.py to work with a single-file FLU… (RyanJDick)
9982bc2 Add docs to the quantization scripts. (RyanJDick)
c5c60f5 Fix max_seq_len field description. (RyanJDick)
5f3e325 Remove automatic install of models during flux model loader, remove n… (brandonrising)
bc6e1ba Run ruff (brandonrising)
407796c Undo changes to the v2 dir of frontend types (brandonrising)
d4ec434 added FLUX dev to starter models (maryhipp)
374dc82 Don't install bitsandbytes on macOS (brandonrising)
80a46d2 Attribute black-forest-labs/flux for much of the flux code (brandonrising)
5406a2f Mark FLUX nodes as prototypes. (RyanJDick)
c9c4e47 Make FLUX get_noise(...) consistent across devices/dtypes. (RyanJDick)
22a3b3d Tidy is_schnell detection logic. (RyanJDick)
5307a6f Add comment about incorrect T5 Tokenizer size calculation. (RyanJDick)
5e9ef4b Rename field positive_prompt -> prompt. (RyanJDick)
0e9f6f7 Move prepare_latent_image_patches(...) to sampling.py with all of the… (RyanJDick)
b5c937e Run FLUX VAE decoding in the user's preferred dtype rather than float… (RyanJDick)
f34a923 Update macos test vm to macOS-14 (brandonrising)
b8d4630 Load and unload clip/t5 encoders and run inference separately in text… (brandonrising)
a31c02b Only import bnb quantize file if bitsandbytes is installed (brandonrising)
9899e42 Switch flux to using its own conditioning field (brandonrising)
cfcd860 Add script for quantizing a T5 model. (RyanJDick)
4089ff2 Fixes to the T5XXL quantization script. (RyanJDick)
54c48c3 Update the T5 8-bit quantized starter model to use the BnB LLM.int8()… (RyanJDick)
5af214b Remove all references to optimum-quanto and downgrade diffusers. (RyanJDick)
f4612a9 Update docs for T5 quantization script. (RyanJDick)
18c0ec3 Move quantization scripts to a scripts/ subdir. (RyanJDick)
098db5c Downgrade revert torch version after removing optimum-quanto, and othe… (RyanJDick)
dbdd851 Update t5 encoder formats to accurately reflect the quantization stra… (brandonrising)
1d6c83b Switch the CLIP-L start model to use our hosted version - which is mu… (RyanJDick)
1413ff9 Replace swish() with torch.nn.functional.silu(h). They are functional… (RyanJDick)
bd1b37d Setup scaffolding for in progress images and add ability to cancel th… (brandonrising)
d159fe6 Remove dependency on flux config files (brandonrising)
877b88e ruff (RyanJDick)
ae94e48 Remove flux repo dependency (RyanJDick)
f046a38 Downgrade accelerate and huggingface-hub deps to original versions. (RyanJDick)
9f6f404 ruff format (RyanJDick)
9a530c7 Remove outdated TODO. (RyanJDick)
bb80697 Only install starter models if not already installed (brandonrising)
642a953 Remove in progress images until we're able to make them valuable (brandonrising)
a90d098 Remove no longer used code in the flux denoise function (brandonrising)
40a3fa5 Fix type error in tsc (brandonrising)
b9238b6 Run ruff (brandonrising)
5a5ca10 Rename params for flux and flux vae, add comments explaining use of t… (brandonrising)
bf59ab3 update default workflow for flux
bd2692b remove prompt
5d42e67 Run ruff (brandonrising)
c510234 default workflow: add steps to exposed fields, add more notes
3b29bad Update starter model size estimates. (RyanJDick)
@@ -0,0 +1,86 @@

from typing import Literal

import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo


@invocation(
    "flux_text_encoder",
    title="FLUX Text Encoding",
    tags=["prompt", "conditioning", "flux"],
    category="conditioning",
    version="1.0.0",
    classification=Classification.Prototype,
)
class FluxTextEncoderInvocation(BaseInvocation):
    """Encodes and preps a prompt for a flux image."""

    clip: CLIPField = InputField(
        title="CLIP",
        description=FieldDescriptions.clip,
        input=Input.Connection,
    )
    t5_encoder: T5EncoderField = InputField(
        title="T5Encoder",
        description=FieldDescriptions.t5_encoder,
        input=Input.Connection,
    )
    t5_max_seq_len: Literal[256, 512] = InputField(
        description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
    )
    prompt: str = InputField(description="Text prompt to encode.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
        t5_embeddings, clip_embeddings = self._encode_prompt(context)
        conditioning_data = ConditioningFieldData(
            conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
        )

        conditioning_name = context.conditioning.save(conditioning_data)
        return FluxConditioningOutput.build(conditioning_name)

    def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
        # Load CLIP.
        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
        clip_text_encoder_info = context.models.load(self.clip.text_encoder)

        # Load T5.
        t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
        t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)

        prompt = [self.prompt]

        with (
            t5_text_encoder_info as t5_text_encoder,
            t5_tokenizer_info as t5_tokenizer,
        ):
            assert isinstance(t5_text_encoder, T5EncoderModel)
            assert isinstance(t5_tokenizer, T5Tokenizer)

            t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)

            prompt_embeds = t5_encoder(prompt)

        with (
            clip_text_encoder_info as clip_text_encoder,
            clip_tokenizer_info as clip_tokenizer,
        ):
            assert isinstance(clip_text_encoder, CLIPTextModel)
            assert isinstance(clip_tokenizer, CLIPTokenizer)

            clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)

            pooled_prompt_embeds = clip_encoder(prompt)

        assert isinstance(prompt_embeds, torch.Tensor)
        assert isinstance(pooled_prompt_embeds, torch.Tensor)
        return prompt_embeds, pooled_prompt_embeds
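The HFEncoder helper imported from invokeai.backend.flux.modules.conditioner is not part of this diff. As a rough sketch of its role (modeled on the equivalent wrapper in the black-forest-labs/flux reference code that this PR attributes; the constructor signature and internals below are assumptions, not a copy of this PR's code), it pairs a tokenizer with its text encoder and returns CLIP's pooled output or T5's per-token hidden states:

import torch
from torch import nn


class HFEncoder(nn.Module):
    """Hypothetical sketch: tokenize prompts and run them through an HF text encoder."""

    def __init__(self, encoder: nn.Module, tokenizer, is_clip: bool, max_length: int):
        super().__init__()
        self.encoder = encoder
        self.tokenizer = tokenizer
        self.max_length = max_length
        # CLIP conditioning uses the pooled summary vector; T5 uses per-token states.
        self.output_key = "pooler_output" if is_clip else "last_hidden_state"

    @torch.no_grad()
    def forward(self, text: list[str]) -> torch.Tensor:
        batch = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        outputs = self.encoder(input_ids=batch["input_ids"].to(self.encoder.device))
        return outputs[self.output_key]

Read this way, the call sites above line up: HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len) yields a (1, t5_max_seq_len, hidden) per-token tensor, while the CLIP call with True and a max length of 77 yields a single pooled (1, hidden) embedding.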
@@ -0,0 +1,172 @@

import torch
from einops import rearrange
from PIL import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    FluxConditioningField,
    Input,
    InputField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "flux_text_to_image",
    title="FLUX Text to Image",
    tags=["image", "flux"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Text-to-image generation using a FLUX model."""

    transformer: TransformerField = InputField(
        description=FieldDescriptions.flux_model,
        input=Input.Connection,
        title="Transformer",
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )
    positive_text_conditioning: FluxConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    num_steps: int = InputField(
        default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
    )
    guidance: float = InputField(
        default=4.0,
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
    )
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Load the conditioning data.
        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
        assert len(cond_data.conditionings) == 1
        flux_conditioning = cond_data.conditionings[0]
        assert isinstance(flux_conditioning, FLUXConditioningInfo)

        latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
        image = self._run_vae_decoding(context, latents)
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)

    def _run_diffusion(
        self,
        context: InvocationContext,
        clip_embeddings: torch.Tensor,
        t5_embeddings: torch.Tensor,
    ):
        transformer_info = context.models.load(self.transformer.transformer)
        inference_dtype = torch.bfloat16

        # Prepare input noise.
        x = get_noise(
            num_samples=1,
            height=self.height,
            width=self.width,
            device=TorchDevice.choose_torch_device(),
            dtype=inference_dtype,
            seed=self.seed,
        )

        img, img_ids = prepare_latent_img_patches(x)

        is_schnell = "schnell" in transformer_info.config.config_path

        timesteps = get_schedule(
            num_steps=self.num_steps,
            image_seq_len=img.shape[1],
            shift=not is_schnell,
        )

        bs, t5_seq_len, _ = t5_embeddings.shape
        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())

        # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
        # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
        # if the cache is not empty.
        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

        with transformer_info as transformer:
            assert isinstance(transformer, Flux)

            def step_callback() -> None:
                if context.util.is_canceled():
                    raise CanceledException

                # TODO: Make this look like the image before re-enabling
                # latent_image = unpack(img.float(), self.height, self.width)
                # latent_image = latent_image.squeeze()  # Remove unnecessary dimensions
                # flattened_tensor = latent_image.reshape(-1)  # Flatten to shape [48*128*128]

                # # Create a new tensor of the required shape [255, 255, 3]
                # latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3)  # Reshape to RGB format

                # # Convert to a NumPy array and then to a PIL Image
                # image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))

                # (width, height) = image.size
                # width *= 8
                # height *= 8

                # dataURL = image_to_dataURL(image, image_format="JPEG")

                # # TODO: move this whole function to invocation context to properly reference these variables
                # context._services.events.emit_invocation_denoise_progress(
                #     context._data.queue_item,
                #     context._data.invocation,
                #     state,
                #     ProgressImage(dataURL=dataURL, width=width, height=height),
                # )

            x = denoise(
                model=transformer,
                img=img,
                img_ids=img_ids,
                txt=t5_embeddings,
                txt_ids=txt_ids,
                vec=clip_embeddings,
                timesteps=timesteps,
                step_callback=step_callback,
                guidance=self.guidance,
            )

        x = unpack(x.float(), self.height, self.width)

        return x

    def _run_vae_decoding(
        self,
        context: InvocationContext,
        latents: torch.Tensor,
    ) -> Image.Image:
        vae_info = context.models.load(self.vae.vae)
        with vae_info as vae:
            assert isinstance(vae, AutoEncoder)
            latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
            img = vae.decode(latents)

        img = img.clamp(-1, 1)
        img = rearrange(img[0], "c h w -> h w c")
        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

        return img_pil
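Two helpers imported from invokeai.backend.flux.sampling carry most of the sampling logic above but are outside this diff. The sketches below follow the black-forest-labs/flux reference implementation that these utilities were adapted from; the exact names, defaults, and details in this PR's sampling.py are assumptions, not a copy of its code. First, prepare_latent_img_patches packs the (b, c, h, w) noise latent into the sequence of 2x2 patches (plus row/column position ids) that the FLUX transformer consumes:

import torch
from einops import rearrange, repeat


def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Sketch: flatten a latent image into 2x2 patch tokens with position ids."""
    bs, c, h, w = latent_img.shape
    # Each token is a 2x2 patch: (b, c, h, w) -> (b, h*w/4, c*4).
    img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    # Position ids per patch: (unused, row, col), matching the 3 channels of txt_ids.
    img_ids = torch.zeros(h // 2, w // 2, 3, device=latent_img.device)
    img_ids[..., 1] = torch.arange(h // 2, device=latent_img.device)[:, None]
    img_ids[..., 2] = torch.arange(w // 2, device=latent_img.device)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
    return img, img_ids

Second, get_schedule explains the shift=not is_schnell argument: FLUX dev warps the uniform timestep grid by an amount that grows with the image sequence length, while schnell keeps it uniform:

import math

import torch


def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    # Warp timesteps; larger mu concentrates more steps at high noise levels.
    # At t=0 the 1/t term is inf in float arithmetic, which correctly maps to 0.
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    timesteps = torch.linspace(1, 0, num_steps + 1)
    if shift:
        # Interpolate mu linearly in sequence length between 256 and 4096 tokens.
        slope = (max_shift - base_shift) / (4096 - 256)
        mu = base_shift + slope * (image_seq_len - 256)
        timesteps = time_shift(mu, 1.0, timesteps)
    return timesteps.tolist()

Under this sketch, the node's 1024x1024 default produces a 128x128 latent (the FLUX VAE downsamples by 8x), i.e. 64x64 = 4096 patch tokens, so mu lands at max_shift and the dev schedule is maximally warped toward the noisy end.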