Skip to content

Conversation

fegin
Copy link
Contributor

@fegin fegin commented Jul 16, 2024

Stack from ghstack (oldest at bottom):

Summary:
This PR demonstrates how to use DeviceMesh reshape to enable HSDP + CP

[ghstack-poisoned]
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jul 16, 2024
fegin added a commit that referenced this pull request Jul 16, 2024
Summary:
This PR demonstrates how to use DeviceMesh reshape to enable HSDP + CP

ghstack-source-id: a57eb56
Pull Request resolved: #463
@fegin fegin marked this pull request as draft July 16, 2024 23:07
[ghstack-poisoned]
fegin added a commit that referenced this pull request Jul 16, 2024
Summary:
This PR demonstrates how to use DeviceMesh reshape to enable HSDP + CP

ghstack-source-id: 8f42a91
Pull Request resolved: #463
[ghstack-poisoned]
fegin added a commit that referenced this pull request Jul 17, 2024
Summary:
This PR demonstrates how to use DeviceMesh reshape to enable HSDP + CP

ghstack-source-id: 7edd3ec
Pull Request resolved: #463
[ghstack-poisoned]
fegin added a commit that referenced this pull request Jul 17, 2024
Summary:
This PR demonstrates how to use DeviceMesh reshape to enable HSDP + CP

ghstack-source-id: fb14c05
Pull Request resolved: #463
[ghstack-poisoned]
fegin added a commit that referenced this pull request Jul 17, 2024
Summary:
This PR demonstrates how to use DeviceMesh reshape to enable HSDP + CP

ghstack-source-id: 189ab55
Pull Request resolved: #463
raise NotImplementedError("CP + TP or CP + PP are not supported yet.")
cp_mesh = world_mesh["cp"]
dp_mesh = world_mesh["dp"]
cp_mesh = dp_mesh.reshape((-1, parallel_dims.cp), ("dp", "cp"))["cp"]
Copy link
Contributor Author

@fegin fegin Jul 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may extend ParallelDims to have helpers to get different submeshs. However, the naming can be tricky as there are multiple meaning of dp_mesh.

assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
dp_mesh = dp_mesh.reshape(
(parallel_dims.dp_replicate, -1),
("dp_replicate", "dp_shard"),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For HSDP + CP, suppose it's a (2, 2, 2) mesh where HSDP + CP applied, how could CP be used together with HSDP? is CP somewhat need to be merge into one of the HSDP dimensions? i.e. HSDP mesh would be (2, 2), and after merging CP it would become (2, 4) or (4, 2)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be (2, 4) for fully_shard and it would be (4,) (the merged dimension of first two dimensions of the world_mesh) for data loader.

[ghstack-poisoned]
fegin added a commit that referenced this pull request Jul 18, 2024
Summary:
This PR demonstrates how to use DeviceMesh reshape to enable HSDP + CP

ghstack-source-id: 5f06d72
Pull Request resolved: #463
@fegin fegin closed this Aug 13, 2024
tianyu-l pushed a commit that referenced this pull request Aug 16, 2024
Summary:
This PR demonstrates how to use DeviceMesh reshape to enable HSDP + CP

ghstack-source-id: 5f06d72
Pull Request resolved: #463
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants