Add meta_init, enable it as default init process #84
Changes from all commits
8699f1c
f211961
2b1871e
6a15267
bbc165f
8fea674
8b20a62
bef7cf9
de2e2a4
7b65935
eee05e8
4261c54
New file (imported below as torchtrain.meta_init):

@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from contextlib import contextmanager

import torch
from torch import nn
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened


@contextmanager
def meta_model_init():
    """init model on meta device"""
    saved_register_parameter = nn.Module.register_parameter
    saved_register_buffer = nn.Module.register_buffer

    def register_meta_param(module, name, param):
        saved_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            module._parameters[name] = param_cls(
                module._parameters[name].to(torch.device("meta")), **kwargs
            )

    def register_meta_buffer(module, name, buffer):
        saved_register_buffer(module, name, buffer)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(torch.device("meta"))

    try:
        nn.Module.register_parameter = register_meta_param
        nn.Module.register_buffer = register_meta_buffer
        yield
    finally:
        nn.Module.register_parameter = saved_register_parameter
        nn.Module.register_buffer = saved_register_buffer


@torch.no_grad()
def meta_to_real_init_fn(module: nn.Module):
    for submodule in module.modules():
        for param_name, param in submodule.named_parameters(recurse=False):
            if not _is_fsdp_flattened(param) and param.is_meta:
                materialized_param = nn.Parameter(
                    torch.randn_like(param, device=torch.device("cuda"))
                )
                setattr(submodule, param_name, materialized_param)
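For context, here is a minimal sketch (not part of the PR) of how these two helpers are meant to compose: the model is constructed under meta_model_init() so its parameters land on the meta device, and meta_to_real_init_fn is later handed to FSDP as param_init_fn so each not-yet-flattened module gets materialized on the local GPU. The toy nn.Sequential model and the direct FSDP(...) call are illustrative stand-ins for the Llama model and the enable_wrap/wrap flow used below.

import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torchtrain.meta_init import meta_model_init, meta_to_real_init_fn

# Assumes torch.distributed is already initialized (e.g. via torchrun) and a
# CUDA device is available on this rank.
with meta_model_init():
    # register_parameter/register_buffer are patched inside this block, so the
    # parameters are created on the meta device and allocate no real memory.
    toy_model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

assert all(p.is_meta for p in toy_model.parameters())

# FSDP invokes param_init_fn on each module before flattening/sharding, which
# replaces the meta parameters with real tensors on the current CUDA device.
sharded_model = FSDP(
    toy_model,
    param_init_fn=meta_to_real_init_fn,
    device_id=torch.cuda.current_device(),
)

In the PR itself the same wiring happens through the "param_init_fn" entry of fsdp_config inside parallelize_llama, as shown in the next file.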
Changes to the file defining parallelize_llama:

@@ -35,6 +35,7 @@
)
from torchtrain.config_manager import JobConfig
from torchtrain.logging_utils import rank0_log
from torchtrain.meta_init import meta_to_real_init_fn


logger = logging.getLogger(__name__)

@@ -193,6 +194,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
    if parallel_dims.dp_enabled:
        dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
        assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names

        fsdp_config = {
            "mixed_precision": MixedPrecision(
                param_dtype=torch.bfloat16,

@@ -204,12 +206,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
            # When torch.compile is active, it requires us to set use_orig_params=True
            "use_orig_params": True,
            "device_mesh": dp_mesh,
            "param_init_fn": meta_to_real_init_fn,
        }

        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
            for layer_id, transformer_block in enumerate(model.layers):
                # before wrapping with FSDP, we need to make sure the layer is on GPU
                transformer_block = transformer_block.cuda()

                # apply selective AC
                transformer_block = checkpoint_wrapper(

@@ -220,10 +221,13 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                model.layers[layer_id] = wrap(transformer_block)

            # wrap the rest layers with FSDP
            model = wrap(model.cuda())
            model = wrap(model)

        rank0_log("Applied FSDP to the model...")
    else:
        model.cuda()

    # redundant if FSDP is enabled, but ensure the model is on device regardless of which parallelisms were used
    model.cuda()

    # we have now moved from meta to device,
    # reset parameters for proper initialization
    model.reset_parameters()
Review comment: I apologize that FSDP meta-device init is confusing, but I think this might not be fully correct. For this Llama case, it looks like perhaps the …

Follow-up comment: … correct that the init does not depend on tensor shape directly.
    return model
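On the reviewer's point above: whether calling reset_parameters() after FSDP wrapping yields the intended initialization hinges on whether the init depends on the tensor's shape. Below is a small illustration (not from the PR; the 4096x1024 weight and world size of 8 are made-up numbers): an element-wise init such as a fixed-std normal behaves identically on a full weight and on a flattened shard, while a fan-in based init such as kaiming_uniform_ derives its bound from the shape and would use the wrong bound on a flattened view.

import math
import torch
from torch import nn

full = torch.empty(4096, 1024)            # hypothetical unsharded weight
shard = torch.empty(4096 * 1024 // 8)     # hypothetical flat 1-D shard (world size 8)

# Element-wise init: every element is drawn from N(0, 0.02**2) regardless of
# the tensor's shape, so applying it to a flat shard is still correct.
nn.init.normal_(full, mean=0.0, std=0.02)
nn.init.normal_(shard, mean=0.0, std=0.02)

# Shape-dependent init: kaiming_uniform_ samples U(-b, b) with b = sqrt(6 / fan_in)
# for its default a=0, and fan_in is read off the tensor's shape.
print(math.sqrt(6 / full.shape[1]))               # bound for the real weight (fan_in = 1024)
print(math.sqrt(6 / shard.view(1, -1).shape[1]))  # much smaller bound a flattened view would get

This matches the follow-up comment in the thread that the init used here does not depend on tensor shape directly, which is what makes the post-wrap reset_parameters() call viable.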