 # This file provides the util functions to apply activation checkpointing to the model.
 # Technically, this is not a part of distributed, but distributed module is the best place to put it.
 
+import os
 from collections import defaultdict
 
 import torch
@@ -279,6 +280,7 @@ def apply_ac(
     model_compile_enabled: bool = False,
     use_flex_attn: bool = False,
     op_sac_save_list: set[torch._ops.OpOverload] | None = None,
+    base_folder: str = "",
 ) -> None:
     """Apply activation checkpointing to the model.
 
@@ -297,15 +299,29 @@ def apply_ac(
         None
     """
 
-    for layer_id, transformer_block in model.layers.named_children():
-        transformer_block = _apply_ac_to_transformer_block(
-            transformer_block,
-            ac_config,
-            base_fqn=f"layers.{layer_id}",
-            model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
-            op_sac_save_list=op_sac_save_list,
-        )
-        model.layers.register_module(layer_id, transformer_block)
+    if ac_config.mode == "memory_budget":
+        assert (model_compile_enabled is True), "Memory budget mode requires model to be compiled"
+        torch._functorch.config.activation_memory_budget_solver = ac_config.activation_memory_budget_solver
+        torch._functorch.config.activation_memory_budget_runtime_estimator = ac_config.activation_memory_budget_runtime_estimator
+        if ac_config.visualize_memory_budget_pareto:
+            pareto_dir = os.path.join(base_folder, "memory_budget_pareto")
+            if not os.path.exists(pareto_dir):
+                os.makedirs(pareto_dir, exist_ok=True)
+            torch._functorch.config.memory_budget_pareto_dir = pareto_dir
+            torch._functorch.config.visualize_memory_budget_pareto = True
+
+        torch._functorch.config.activation_memory_budget = ac_config.activation_memory_budget
+        logger.info(f"Selected {ac_config.activation_memory_budget} budget option")
+    else:
+        for layer_id, transformer_block in model.layers.named_children():
+            transformer_block = _apply_ac_to_transformer_block(
+                transformer_block,
+                ac_config,
+                base_fqn=f"layers.{layer_id}",
+                model_compile_enabled=model_compile_enabled,
+                use_flex_attn=use_flex_attn,
+                op_sac_save_list=op_sac_save_list,
+            )
+            model.layers.register_module(layer_id, transformer_block)
 
     logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
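
Context for the new branch: in memory_budget mode, apply_ac no longer wraps individual transformer blocks in checkpoint wrappers; it only sets torch._functorch.config knobs, and the min-cut partitioner that runs under torch.compile then decides per-op what to save versus recompute so activation memory stays under the given fraction of the no-checkpointing baseline (that is also why the assert requires model_compile_enabled). A minimal standalone sketch of that mechanism, assuming only public PyTorch behavior; the toy model and the 0.5 budget below are illustrative and not part of this PR:

import torch
import torch._functorch.config

# Fraction of the eager-mode activation memory the partitioner may keep.
# 1.0 (the default) saves everything; smaller values force recomputation.
# In this PR the value, the solver, and the pareto-visualization fields are
# taken from ac_config instead of being hard-coded.
torch._functorch.config.activation_memory_budget = 0.5

# Illustrative model only -- in torchtitan this is the full Transformer.
model = torch.compile(
    torch.nn.Sequential(
        torch.nn.Linear(1024, 4096),
        torch.nn.GELU(),
        torch.nn.Linear(4096, 1024),
    )
)
x = torch.randn(32, 1024, requires_grad=True)
# When the backward graph is built, the partitioner chooses save-vs-recompute
# under the configured budget; without torch.compile the knob has no effect.
model(x).sum().backward()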