[not for land] Use new AC #1294

Open · wants to merge 3 commits into base: gh/soulitzer/1/base
2 changes: 2 additions & 0 deletions .github/workflows/integration_test_8gpu.yaml
@@ -36,6 +36,8 @@ jobs:

pip config --user set global.progress_bar off

git clone https://github.com/soulitzer/ac-experimental.git && cd ac-experimental && pip install -e . && cd ..

python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126

USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126
82 changes: 52 additions & 30 deletions torchtitan/models/llama3/parallelize_llama.py
@@ -12,10 +12,36 @@
import torch
import torch.nn as nn
from torch.distributed._composable.replicate import replicate

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
    ActivationWrapper,
)


class CheckpointWrapper(ActivationWrapper):
    def __init__(self, mod: torch.nn.Module, **kwargs):
        super().__init__(mod)
        self._checkpoint_wrapped_module = mod
        self._make_policy_fn = kwargs.get("make_policy_fn", None)

    def forward(self, *args, **kwargs):
        from ac_experimental import apply_ac_policy_fn

        if self._make_policy_fn is None:
            return apply_ac_policy_fn(
                self._checkpoint_wrapped_module, *args, **kwargs, policy_fn="recompute_all"
            )
        else:
            # Pass is_factory=True so that a new instance of policy_fn is created per AC invocation
            return apply_ac_policy_fn(
                self._checkpoint_wrapped_module, *args, **kwargs, policy_fn=self._make_policy_fn, is_factory=True
            )


def ptd_checkpoint_wrapper(mod, **kwargs):
    return CheckpointWrapper(mod, **kwargs)


from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
from torch.distributed.tensor import Replicate, Shard
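
A minimal usage sketch of the wrapper above (illustrative, not part of this diff; it assumes the ac-experimental package installed by the workflow change is importable and reuses the torch/nn imports and ptd_checkpoint_wrapper defined in this file). With no make_policy_fn, forward falls back to the "recompute_all" policy:

block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
wrapped = ptd_checkpoint_wrapper(block)  # no make_policy_fn -> policy_fn="recompute_all"
x = torch.randn(4, 16, requires_grad=True)
out = wrapped(x)      # forward dispatches through ac_experimental.apply_ac_policy_fn
out.sum().backward()  # activations are recomputed during backward rather than saved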
@@ -226,6 +252,29 @@ def apply_tp(
    torch.ops.aten.max.default,
}

from torch.utils.checkpoint import CheckpointPolicy

# If you want your policy to have state, pass a class. Make sure to
# create it in global scope to avoid new instances triggering recompiles.
class CustomPolicy:
    def __init__(self):
        super().__init__()
        self.meta = dict()

    def __call__(self, ctx, out, func, *args, **kwargs):
        mm_count_key = "mm_count"
        if func == torch.ops.aten.mm.default:
            self.meta[mm_count_key] = self.meta.get(mm_count_key, 0) + 1

        # Saves output of all compute ops, except every second mm
        to_save = func in _save_list and not (
            func == torch.ops.aten.mm.default and self.meta[mm_count_key] % 2 == 0
        )
        return (
            CheckpointPolicy.MUST_SAVE
            if to_save
            else CheckpointPolicy.PREFER_RECOMPUTE
        )

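A minimal sketch of the decisions CustomPolicy produces (illustrative, not part of this diff; ctx and out are unused by this policy, and it assumes torch.ops.aten.mm.default is in the _save_list defined earlier in this file):

policy = CustomPolicy()
decisions = [policy(None, None, torch.ops.aten.mm.default) for _ in range(4)]
# decisions == [MUST_SAVE, PREFER_RECOMPUTE, MUST_SAVE, PREFER_RECOMPUTE]:
# every second matmul is recomputed in backward instead of having its output saved.
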
def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
    valid_ac_modes = ("full", "selective")
@@ -246,38 +295,11 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
            f"Valid options: 'op' or a positive int representing layer frequency"
        )
    if use_op_sac:
if use_op_sac:
        from torch.utils.checkpoint import (
            CheckpointPolicy,
            create_selective_checkpoint_contexts,
        )

        def _get_custom_policy(meta):
            def _custom_policy(ctx, func, *args, **kwargs):
                mode = "recompute" if ctx.is_recompute else "forward"
                mm_count_key = f"{mode}_mm_count"
                if func == torch.ops.aten.mm.default:
                    meta[mm_count_key] += 1
                # Saves output of all compute ops, except every second mm
                to_save = func in _save_list and not (
                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
                )
                return (
                    CheckpointPolicy.MUST_SAVE
                    if to_save
                    else CheckpointPolicy.PREFER_RECOMPUTE
                )

            return _custom_policy

        def selective_checkpointing_context_fn():
            meta = defaultdict(int)
            return create_selective_checkpoint_contexts(_get_custom_policy(meta))

        return ptd_checkpoint_wrapper(
            module,
            context_fn=selective_checkpointing_context_fn,
            preserve_rng_state=False,
            make_policy_fn=CustomPolicy,
        )

    elif use_layer_sac:
        # Checkpoint every `ac_freq` of the modules passed to this function
        ac_freq = int(ac_config.selective_ac_option)
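
An end-to-end sketch of the op-level SAC path above (illustrative, not part of this diff; SimpleNamespace is a hypothetical stand-in for torchtitan's activation-checkpoint config, assuming mode="selective" with selective_ac_option="op" selects the use_op_sac branch):

from types import SimpleNamespace

ac_config = SimpleNamespace(mode="selective", selective_ac_option="op")  # hypothetical stand-in config
block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
wrapped = _apply_ac_to_transformer_block(block, ac_config)  # -> CheckpointWrapper(make_policy_fn=CustomPolicy)
out = wrapped(torch.randn(4, 16))  # saves outputs of ops in _save_list, except every second mm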