Commit 3646ad8

Add support for AC budget API

1 parent: eb13ba2

File tree (7 files changed, +61 -11 lines):
  torchtitan/config/job_config.py
  torchtitan/distributed/activation_checkpoint.py
  torchtitan/experiments/llama4/infra/parallelize.py
  torchtitan/experiments/qwen3/infra/parallelize.py
  torchtitan/experiments/simple_fsdp/llama3/parallelize.py
  torchtitan/models/deepseek_v3/infra/parallelize.py
  torchtitan/models/llama3/infra/parallelize.py

torchtitan/config/job_config.py

Lines changed: 30 additions & 1 deletion
@@ -538,7 +538,7 @@ class Checkpoint:
 
 @dataclass
 class ActivationCheckpoint:
-    mode: Literal["selective", "full", "none"] = "selective"
+    mode: Literal["selective", "full", "memory_budget", "none"] = "selective"
     """Type of activation checkpointing to use"""
 
     selective_ac_option: str = "2"
@@ -566,6 +566,35 @@ class ActivationCheckpoint:
     Whether to stop recomputing early when all activations have already been
     rematerialized.
     """
+    activation_memory_budget: float = 1.0
+    """
+    When mode is set to "memory_budget", this value determines how much the
+    partitioner in the compiler should trade off compute for memory.
+    0.0 corresponds to the activation memory from applying
+    activation checkpointing to the full compiled region, and 1.0 corresponds to
+    the activation memory of the default runtime-optimized strategy.
+    """
+    activation_memory_budget_runtime_estimator: Literal["flops", "profile"] = "flops"
+    """
+    This controls how we estimate the runtime when deciding which operators are
+    the cheapest to recompute. The 3 options are
+    "flops": bases it on the flop count provided by torch.utils.flop_counter
+    "profile": benchmarks each operator to come up with a runtime
+    "testing": returns 1 for everything
+    """
+    activation_memory_budget_solver: Literal["dp", "greedy", "ilp"] = "dp"
+    """
+    This controls the solver used for the 0-1 knapsack. By default we use a
+    quantized DP solution ("dp"). The other approaches are "greedy" and "ilp"
+    (which has a scipy dependency).
+    """
+    visualize_memory_budget_pareto: bool = False
+    """
+    This dumps out an SVG visualization of the expected runtime vs. activation
+    memory tradeoffs for all memory budget values from 0 to 1 in increments of
+    0.5. See an example here:
+    https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015
+    """
 
 
 @dataclass
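For illustration, here is a minimal sketch of enabling the new mode programmatically. It assumes the ActivationCheckpoint dataclass above is importable from torchtitan.config.job_config; in a real run these values would normally come from the job TOML/CLI rather than being constructed by hand.

# Hedged sketch: configuring memory-budget activation checkpointing.
from torchtitan.config.job_config import ActivationCheckpoint

ac_config = ActivationCheckpoint(
    mode="memory_budget",
    # 0.0 ~ checkpoint the full compiled region, 1.0 ~ default runtime-optimized strategy
    activation_memory_budget=0.5,
    activation_memory_budget_runtime_estimator="flops",
    activation_memory_budget_solver="dp",
    visualize_memory_budget_pareto=False,
)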

torchtitan/distributed/activation_checkpoint.py

Lines changed: 26 additions & 10 deletions
@@ -7,6 +7,7 @@
 # This file provides the util functions to apply activation checkpointing to the model.
 # Technically, this is not a part of distributed, but distributed module is the best place to put it.
 
+import os
 from collections import defaultdict
 
 import torch
@@ -279,6 +280,7 @@ def apply_ac(
     model_compile_enabled: bool = False,
     use_flex_attn: bool = False,
     op_sac_save_list: set[torch._ops.OpOverload] | None = None,
+    base_folder: str = "",
 ) -> None:
     """Apply activation checkpointing to the model.
 
@@ -297,15 +299,29 @@ def apply_ac(
         None
     """
 
-    for layer_id, transformer_block in model.layers.named_children():
-        transformer_block = _apply_ac_to_transformer_block(
-            transformer_block,
-            ac_config,
-            base_fqn=f"layers.{layer_id}",
-            model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
-            op_sac_save_list=op_sac_save_list,
-        )
-        model.layers.register_module(layer_id, transformer_block)
+    if ac_config.mode == "memory_budget":
+        assert model_compile_enabled, "Memory budget mode requires model to be compiled"
+        torch._functorch.config.activation_memory_budget_solver = ac_config.activation_memory_budget_solver
+        torch._functorch.config.activation_memory_budget_runtime_estimator = ac_config.activation_memory_budget_runtime_estimator
+        if ac_config.visualize_memory_budget_pareto:
+            pareto_dir = os.path.join(base_folder, "memory_budget_pareto")
+            if not os.path.exists(pareto_dir):
+                os.makedirs(pareto_dir, exist_ok=True)
+            torch._functorch.config.memory_budget_pareto_dir = pareto_dir
+            torch._functorch.config.visualize_memory_budget_pareto = True
+
+        torch._functorch.config.activation_memory_budget = ac_config.activation_memory_budget
+        logger.info(f"Selected {ac_config.activation_memory_budget} budget option")
+    else:
+        for layer_id, transformer_block in model.layers.named_children():
+            transformer_block = _apply_ac_to_transformer_block(
+                transformer_block,
+                ac_config,
+                base_fqn=f"layers.{layer_id}",
+                model_compile_enabled=model_compile_enabled,
+                use_flex_attn=use_flex_attn,
+                op_sac_save_list=op_sac_save_list,
+            )
+            model.layers.register_module(layer_id, transformer_block)
 
     logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 1 addition & 0 deletions
@@ -120,6 +120,7 @@ def parallelize_llama(
         model_compile_enabled=model_compile_enabled,
         use_flex_attn=use_flex_attn,
         op_sac_save_list=_op_sac_save_list,
+        base_folder=job_config.job.dump_folder,
     )
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP

torchtitan/experiments/qwen3/infra/parallelize.py

Lines changed: 1 addition & 0 deletions
@@ -114,6 +114,7 @@ def parallelize_qwen3(
         model_compile_enabled=model_compile_enabled,
         use_flex_attn=use_flex_attn,
         op_sac_save_list=_op_sac_save_list,
+        base_folder=job_config.job.dump_folder,
     )
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP

torchtitan/experiments/simple_fsdp/llama3/parallelize.py

Lines changed: 1 addition & 0 deletions
@@ -85,6 +85,7 @@ def parallelize_llama(
         model_compile_enabled=model_compile_enabled,
         use_flex_attn=use_flex_attn,
         op_sac_save_list=_op_sac_save_list,
+        base_folder=job_config.job.dump_folder,
     )
 
     # apply data parallel

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 1 addition & 0 deletions
@@ -113,6 +113,7 @@ def parallelize_deepseekv3(
         model_compile_enabled=model_compile_enabled,
         use_flex_attn=use_flex_attn,
         op_sac_save_list=_op_sac_save_list,
+        base_folder=job_config.job.dump_folder,
     )
 
     if model_compile_enabled:

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 1 addition & 0 deletions
@@ -102,6 +102,7 @@ def parallelize_llama(
         model_compile_enabled=model_compile_enabled,
         use_flex_attn=use_flex_attn,
         op_sac_save_list=_op_sac_save_list,
+        base_folder=job_config.job.dump_folder,
     )
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
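Each of these call sites makes the same one-line change: forwarding the job dump folder so the optional Pareto SVGs land under the run's output directory. A hedged sketch of the resulting call (the positional model and ac_config arguments are inferred from the surrounding code and may differ per model):

# Sketch of the updated apply_ac call shared by the parallelize_* functions.
apply_ac(
    model,
    job_config.activation_checkpoint,
    model_compile_enabled=model_compile_enabled,
    use_flex_attn=use_flex_attn,
    op_sac_save_list=_op_sac_save_list,
    base_folder=job_config.job.dump_folder,  # new in this commit
)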
