41 changes: 39 additions & 2 deletions torchtitan/experiments/simple_fsdp/parallelize.py
@@ -93,8 +93,45 @@ def parallelize_llama(
         )
         logger.info("Applied Data Parallel (dp mode=%s) to the model", dp_mode)
 
-    if job_config.compile.enable and "model" in job_config.compile.components:
-        torch._inductor.config.reorder_for_peak_memory = False
+    if job_config.compile.enable:
+        from functools import partial
+        bucket_level = ""  # "" disables bucketing; "inductor" or "aten" selects a pass
+        torch._inductor.config.run_with_post_grad_graph = True
+        if bucket_level == "inductor":
+            # enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282)
+            from autoparallel.auto_bucketing import (
+                simple_fsdp_autobucketing_reordering_pass,
+                simplefsdp_autobucketing_config,
+            )
+
+            torch._inductor.config.allow_buffer_reuse = False
+            torch._inductor.config.reorder_for_peak_memory = False
+            torch._inductor.config.reorder_for_compute_comm_overlap = True
+            simplefsdp_autobucketing_config.save_estimation_path = (
+                "/tmp/torchtitan_simplefsdp_comm_estimation.pkl"
+            )
+            simplefsdp_autobucketing_config.calibrate_number = 20
+            simple_fsdp_autobucketing_reordering_pass = partial(
+                simple_fsdp_autobucketing_reordering_pass,
+                configs=simplefsdp_autobucketing_config,
+            )
+            torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
+                simple_fsdp_autobucketing_reordering_pass
+            ]
+
+            # Don't use both sets of passes at the same time!
+            torch._inductor.config.bucket_all_gathers_fx = "none"
+            torch._inductor.config.bucket_reduce_scatters_fx = "none"
+        elif bucket_level == "aten":
+            from autoparallel.auto_bucketing import aten_autobucketing_reordering_pass, aten_autobucketing_config
+            torch._inductor.config.reorder_for_peak_memory = False
+            torch._inductor.config.reorder_for_compute_comm_overlap = False
+            aten_autobucketing_reordering_pass = partial(
+                aten_autobucketing_reordering_pass,
+                configs=aten_autobucketing_config,
+            )
+            torch._inductor.config.post_grad_custom_post_pass = aten_autobucketing_reordering_pass
+
         model = torch.compile(model, fullgraph=True)
 
     return model
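
For context on the two hooks this diff uses: the "inductor" path registers a scheduling pass via torch._inductor.config.reorder_for_compute_comm_overlap_passes, which inductor invokes when reorder_for_compute_comm_overlap is set, while the "aten" path instead attaches its pass to post_grad_custom_post_pass, which runs on the post-grad FX graph. Below is a minimal sketch of the first hook, assuming the same callable contract as the built-in passes in torch/_inductor/comms.py (a list of scheduler nodes in, a possibly reordered list out); noop_reordering_pass is a hypothetical stand-in for the autoparallel pass, not part of either library.

from functools import partial

import torch


def noop_reordering_pass(snodes, configs=None):
    # Receives inductor's scheduler nodes; returning them unchanged keeps
    # the default schedule. A real pass (like the autobucketing pass above)
    # would bucket collectives and move them to overlap with compute.
    return snodes


torch._inductor.config.reorder_for_compute_comm_overlap = True
torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
    partial(noop_reordering_pass, configs=None),
]

Passing a partial with a config object bound, as the PR does, keeps the pass signature uniform while letting callers tune it without globals.
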
2 changes: 1 addition & 1 deletion torchtitan/models/llama3/train_configs/llama3_8b.toml
@@ -34,7 +34,7 @@ local_batch_size = 1
 seq_len = 8192
 max_norm = 1.0  # grad norm clipping
 steps = 1000
-dataset = "c4"
+dataset = "c4_test"
 
 [parallelism]
 data_parallel_replicate_degree = 1
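
Note: "c4_test" refers to the small C4 sample bundled with torchtitan (under tests/assets/c4_test), so this change presumably swaps the full C4 download for a quick local smoke-test run while iterating on the compile path.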