Skip to content

Use flash_attention for SD2 #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/text_to_image/train_text_to_image_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def main(args):
server = xp.start_server(9012)

num_devices = xr.global_runtime_device_count()
mesh = xs.get_1d_mesh('x')
mesh = xs.get_1d_mesh('data')
xs.set_global_mesh(mesh)

text_encoder = CLIPTextModel.from_pretrained(
Expand Down Expand Up @@ -521,9 +521,9 @@ def collate_fn(examples):
device,
input_sharding={
"pixel_values": xs.ShardingSpec(
mesh, ("x", None, None, None), minibatch=True
mesh, ("data", None, None, None), minibatch=True
),
"input_ids": xs.ShardingSpec(mesh, ("x", None), minibatch=True),
"input_ids": xs.ShardingSpec(mesh, ("data", None), minibatch=True),
},
loader_prefetch_size=args.loader_prefetch_size,
device_prefetch_size=args.device_prefetch_size,
Expand Down
25 changes: 21 additions & 4 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,15 @@
from torch import nn

from ..image_processor import IPAdapterMaskProcessor
from ..utils import deprecate, logging
from ..utils import deprecate, logging, is_torch_xla_available
from ..utils.import_utils import is_torch_npu_available, is_xformers_available
from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph

if is_torch_xla_available():
from torch_xla.experimental.custom_kernel import flash_attention
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

Expand Down Expand Up @@ -2364,9 +2369,21 @@ def __call__(

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

if XLA_AVAILABLE and all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
if attention_mask is not None:
attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
# Convert mask to float and replace 0s with -inf and 1s with 0
attention_mask = attention_mask.float().masked_fill(attention_mask == 0, float('-inf')).masked_fill(attention_mask == 1, float(0.0))

# Apply attention mask to key
key = key + attention_mask
query /= math.sqrt(query.shape[3])
hidden_states = flash_attention(query, key, value, causal=False, partition_spec=('data', None, None, None))
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
Expand Down
Loading