
Conversation


@SherlockNoMad commented Oct 3, 2025

This is an end-to-end prototype that runs llama3-simplefsdp through the export-style aot_autograd workflow.

Setup: dp_shard = 2, tp = 4 (8 GPUs).
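For orientation, a minimal sketch of the assumed 2x4 device mesh layout using the standard init_device_mesh API. torchtitan builds its own world mesh from the parallelism config, so this is illustrative only:

# Illustrative only: assumes launch under torchrun with 8 ranks;
# torchtitan constructs this mesh internally from the parallelism config.
from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_shard", "tp"))
tp_mesh = world_mesh["tp"]          # 1-D submesh used for tensor parallelism
dp_mesh = world_mesh["dp_shard"]    # 1-D submesh used for FSDP-style sharding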

MVP

  • [Done] Start with a SimpleFSDP model; enable TP + FSDP
  • [Done] Apply aot_export_joint_with_descriptors on the parallelized module with DTensor inputs to get the joint graph
  • [Done] Apply min_cut_partitioner to split the joint graph into forward and backward graph modules
  • [Done, needs verification] Apply prefetch/bucketing graph passes on fw_gm and bw_gm to reorder/group the communication collectives
  • [Done] Run the joint graph with aot_compile_joint_with_descriptors (see the sketch after this list)
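A minimal sketch of how these steps can be wired together is below. It assumes the aot_export_joint_with_descriptors / aot_compile_joint_with_descriptors entry points in torch._functorch.aot_autograd; the exact keyword arguments, the ExitStack handling, and the placeholder compiler hooks are assumptions for illustration, not the code in this PR.

# Hedged sketch of the export -> partition -> compile flow described above.
import contextlib
import torch
from torch._functorch.aot_autograd import (
    aot_compile_joint_with_descriptors,
    aot_export_joint_with_descriptors,
)

def compile_parallelized_module(parallel_mod: torch.nn.Module, dt_args: tuple):
    # dt_args are DTensor inputs (see the DTensor.from_local conversion below).
    with contextlib.ExitStack() as stack:
        # Trace the joint forward+backward graph with descriptors.
        joint = aot_export_joint_with_descriptors(stack, parallel_mod, dt_args)

        # Partition into fw/bw graph modules and compile; prefetch/bucketing
        # passes would run in the fw/bw compiler hooks below (placeholders here;
        # the keyword names are an assumption).
        compiled_fn = aot_compile_joint_with_descriptors(
            joint,
            fw_compiler=lambda gm, example_inputs: gm,
            bw_compiler=lambda gm, example_inputs: gm,
        )
    return compiled_fn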

Issues

Repro steps:
NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" with-proxy ./run_train.sh --model.name joint_graph_runner.llama3 --compile.enable --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4

Run with FlexAttention:
NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" with-proxy ./run_train.sh --model.name joint_graph_runner.llama3 --compile.enable --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn

Sample output:
P1975157784: rank0_autograd_function_0fea2786.py
P1975158481: rank1_autograd_function_28587623.py

@meta-cla bot added the CLA Signed label Oct 3, 2025
@SherlockNoMad changed the title from Joint Graph Runner to JointGraph-based Training Prototype Oct 3, 2025
@SherlockNoMad marked this pull request as ready for review October 9, 2025
@tianyu-l (Contributor) left a comment


Is this for exploration purposes? If so, I'd suggest we work in a branch / fork.

Comment on lines +296 to +307
# Hack: convert args and kwargs to DTensor. This should be fixed at the data loader.
# This works, but kinda cheating?
dt_args = tuple(
    DTensor.from_local(arg, self.parallel_dims.world_mesh["tp"], [Replicate()])
    for arg in args
)

# RuntimeError('Sharding propagation failed for Op(op=aten.embedding.default, args_schema=Spec(S(0) on (2048, 256)), Spec((Shard(dim=0), Replicate()) on (16, 2048)) @ mesh: (2, 4))')
# dt_args = tuple(DTensor.from_local(arg, self.parallel_dims.world_mesh, [Shard(0), Replicate()]) for arg in args)

# RuntimeError('Sharding propagation failed for Op(op=aten.embedding.default, args_schema=Spec(S(0) on (2048, 256)), Spec(S(0) on (16, 2048)) @ mesh: (2,))')
# dt_args = tuple(DTensor.from_local(arg, self.parallel_dims.world_mesh["dp_shard"], [Shard(0)]) for arg in args)
@SherlockNoMad (Author) replied:

here
