Commit 0003397

Merge pull request #7 from pytorch-tpu/piz/ft
Use flash_attention for SD2
2 parents: 336f07c + 17caea4

2 files changed: 24 additions, 7 deletions
examples/text_to_image/train_text_to_image_xla.py

Lines changed: 3 additions & 3 deletions
@@ -385,7 +385,7 @@ def main(args):
     server = xp.start_server(9012)
 
     num_devices = xr.global_runtime_device_count()
-    mesh = xs.get_1d_mesh('x')
+    mesh = xs.get_1d_mesh('data')
     xs.set_global_mesh(mesh)
 
     text_encoder = CLIPTextModel.from_pretrained(
@@ -521,9 +521,9 @@ def collate_fn(examples):
         device,
         input_sharding={
             "pixel_values": xs.ShardingSpec(
-                mesh, ("x", None, None, None), minibatch=True
+                mesh, ("data", None, None, None), minibatch=True
             ),
-            "input_ids": xs.ShardingSpec(mesh, ("x", None), minibatch=True),
+            "input_ids": xs.ShardingSpec(mesh, ("data", None), minibatch=True),
         },
         loader_prefetch_size=args.loader_prefetch_size,
         device_prefetch_size=args.device_prefetch_size,
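
For context, the renamed mesh axis has to line up with the partition_spec passed to flash_attention below, so the training script now calls its single mesh axis "data". A minimal sketch of how that mesh and the minibatch sharding specs plug into an XLA device loader, assuming the usual torch_xla import aliases and SPMD mode; the DataLoader, device, and MpDeviceLoader wiring here are illustrative, only the mesh and ShardingSpec calls appear in this diff:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()              # assumed setup: enable SPMD execution
device = xm.xla_device()

# One mesh axis named "data" spanning all devices; the name must match the
# axis names used in the ShardingSpecs and in flash_attention's partition_spec.
mesh = xs.get_1d_mesh("data")
xs.set_global_mesh(mesh)

# Shard the batch dimension of each input along the "data" axis while loading.
# minibatch=True indicates each host loads only its own slice of the global batch.
loader = pl.MpDeviceLoader(
    train_dataloader,  # hypothetical: an ordinary torch.utils.data.DataLoader
    device,
    input_sharding={
        "pixel_values": xs.ShardingSpec(mesh, ("data", None, None, None), minibatch=True),
        "input_ids": xs.ShardingSpec(mesh, ("data", None), minibatch=True),
    },
)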

src/diffusers/models/attention_processor.py

Lines changed: 21 additions & 4 deletions
@@ -20,10 +20,15 @@
 from torch import nn
 
 from ..image_processor import IPAdapterMaskProcessor
-from ..utils import deprecate, logging
+from ..utils import deprecate, logging, is_torch_xla_available
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
 
+if is_torch_xla_available():
+    from torch_xla.experimental.custom_kernel import flash_attention
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -2364,9 +2369,21 @@ def __call__(
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
+
+        if XLA_AVAILABLE and all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
+            if attention_mask is not None:
+                attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
+                # Convert mask to float and replace 0s with -inf and 1s with 0
+                attention_mask = attention_mask.float().masked_fill(attention_mask == 0, float('-inf')).masked_fill(attention_mask == 1, float(0.0))
+
+                # Apply attention mask to key
+                key = key + attention_mask
+            query /= math.sqrt(query.shape[3])
+            hidden_states = flash_attention(query, key, value, causal=False, partition_spec=('data', None, None, None))
+        else:
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
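
Two things about the new branch are worth noting: the Pallas flash_attention kernel is used only when query, key, and value all have sequence length of at least 4096 (shorter sequences, such as the text-encoder cross-attention states, still go through F.scaled_dot_product_attention), and since the kernel call here takes no mask argument, the additive mask is folded into key while query is pre-scaled by 1/sqrt(head_dim). A minimal standalone sketch of that call path, assuming a TPU SPMD setup like the training script's; the tensor shapes and the use of xr.use_spmd()/get_1d_mesh are illustrative assumptions, only the flash_attention arguments themselves come from the diff:

import math

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()                                # assumed: SPMD mode, as in the training script
xs.set_global_mesh(xs.get_1d_mesh("data"))   # axis name matches the partition_spec below
device = xm.xla_device()

# Illustrative shapes only: (batch, num_heads, seq_len, head_dim), with
# seq_len >= 4096 so the patched processor would take the flash path.
query = torch.randn(4, 8, 4096, 64, device=device, dtype=torch.bfloat16)
key = torch.randn(4, 8, 4096, 64, device=device, dtype=torch.bfloat16)
value = torch.randn(4, 8, 4096, 64, device=device, dtype=torch.bfloat16)

# The kernel is handed an already-scaled query, mirroring the patch's
# `query /= math.sqrt(query.shape[3])`; the output stays sharded along the
# "data" mesh axis on the batch dimension.
query = query / math.sqrt(query.shape[3])
hidden_states = flash_attention(query, key, value, causal=False,
                                partition_spec=("data", None, None, None))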
