
Commit 17caea4

apply partition spec
1 parent 57c8b72 commit 17caea4

File tree

2 files changed: +4 -4 lines changed


examples/text_to_image/train_text_to_image_xla.py

Lines changed: 3 additions & 3 deletions
@@ -385,7 +385,7 @@ def main(args):
     server = xp.start_server(9012)

     num_devices = xr.global_runtime_device_count()
-    mesh = xs.get_1d_mesh('x')
+    mesh = xs.get_1d_mesh('data')
     xs.set_global_mesh(mesh)

     text_encoder = CLIPTextModel.from_pretrained(
@@ -521,9 +521,9 @@ def collate_fn(examples):
         device,
         input_sharding={
             "pixel_values": xs.ShardingSpec(
-                mesh, ("x", None, None, None), minibatch=True
+                mesh, ("data", None, None, None), minibatch=True
             ),
-            "input_ids": xs.ShardingSpec(mesh, ("x", None), minibatch=True),
+            "input_ids": xs.ShardingSpec(mesh, ("data", None), minibatch=True),
         },
         loader_prefetch_size=args.loader_prefetch_size,
         device_prefetch_size=args.device_prefetch_size,
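For context, a minimal standalone sketch (not taken from this commit) of how the renamed "data" mesh axis is consumed by torch_xla's SPMD input sharding. The names train_dataloader and device are assumed to already exist in the training script, and the input shapes are illustrative only.

import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl

xr.use_spmd()

# One mesh axis named "data" spanning all devices; the same axis name must be
# used by every ShardingSpec and partition_spec that refers to it.
mesh = xs.get_1d_mesh("data")
xs.set_global_mesh(mesh)

# Shard the batch dimension of each input over the "data" axis. minibatch=True
# tells the loader each host contributes only its local slice of the global batch.
train_device_loader = pl.MpDeviceLoader(
    train_dataloader,  # assumed: the script's existing DataLoader
    device,            # assumed: the script's XLA device
    input_sharding={
        "pixel_values": xs.ShardingSpec(mesh, ("data", None, None, None), minibatch=True),
        "input_ids": xs.ShardingSpec(mesh, ("data", None), minibatch=True),
    },
)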

src/diffusers/models/attention_processor.py

Lines changed: 1 addition & 1 deletion
@@ -2379,7 +2379,7 @@ def __call__(
             # Apply attention mask to key
             key = key + attention_mask
             query /= math.sqrt(query.shape[3])
-            hidden_states = flash_attention(query, key, value, causal=False)
+            hidden_states = flash_attention(query, key, value, causal=False, partition_spec=('data', None, None, None))
         else:
             hidden_states = F.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
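A minimal sketch of the Pallas flash_attention call with a partition spec, assuming the same one-axis "data" mesh as above; tensor shapes are illustrative and a TPU SPMD runtime is required.

import math
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
xs.set_global_mesh(xs.get_1d_mesh("data"))
device = xm.xla_device()

# (batch, heads, seq_len, head_dim); the batch axis is split over the "data" axis.
query = torch.randn(4, 8, 128, 64, device=device)
key = torch.randn(4, 8, 128, 64, device=device)
value = torch.randn(4, 8, 128, 64, device=device)

# The kernel is called with its default softmax scale, so the query is
# pre-scaled by 1/sqrt(head_dim), mirroring the diffusers code path above.
query /= math.sqrt(query.shape[3])
hidden_states = flash_attention(
    query, key, value,
    causal=False,
    partition_spec=("data", None, None, None),
)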
