Skip to content

Commit 7157227

Browse files
committed
apply partition_spec
1 parent 57c8b72 commit 7157227

File tree

2 files changed: 5 additions and 4 deletions (+5 -4 lines changed)

examples/text_to_image/train_text_to_image_xla.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -385,7 +385,7 @@ def main(args):
385 385   server = xp.start_server(9012)
386 386
387 387   num_devices = xr.global_runtime_device_count()
388     - mesh = xs.get_1d_mesh('x')
    388 + mesh = xs.get_1d_mesh('data')
389 389   xs.set_global_mesh(mesh)
390 390
391 391   text_encoder = CLIPTextModel.from_pretrained(
@@ -521,9 +521,9 @@ def collate_fn(examples):
521 521   device,
522 522   input_sharding={
523 523   "pixel_values": xs.ShardingSpec(
524     - mesh, ("x", None, None, None), minibatch=True
    524 + mesh, ("data", None, None, None), minibatch=True
525 525   ),
526     - "input_ids": xs.ShardingSpec(mesh, ("x", None), minibatch=True),
    526 + "input_ids": xs.ShardingSpec(mesh, ("data", None), minibatch=True),
527 527   },
528 528   loader_prefetch_size=args.loader_prefetch_size,
529 529   device_prefetch_size=args.device_prefetch_size,

src/diffusers/models/attention_processor.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
25 25   from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
26 26
27 27   if is_torch_xla_available():
   28 + import torch_xla.distributed.spmd as xs
28 29   from torch_xla.experimental.custom_kernel import flash_attention
29 30   XLA_AVAILABLE = True
30 31   else:
@@ -2379,7 +2380,7 @@ def __call__(
2379 2380   # Apply attention mask to key
2380 2381   key = key + attention_mask
2381 2382   query /= math.sqrt(query.shape[3])
2382      - hidden_states = flash_attention(query, key, value, causal=False)
     2383 + hidden_states = flash_attention(query, key, value, causal=False, partition_spec=('data', None, None, None))
2383 2384   else:
2384 2385   hidden_states = F.scaled_dot_product_attention(
2385 2386   query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False

0 commit comments

Comments (0)