From 67407e6a4bcf1ab9c34dcedb9b81b6425af49cf1 Mon Sep 17 00:00:00 2001 From: mori360 Date: Wed, 16 Oct 2024 17:09:49 -0700 Subject: [PATCH 01/16] enable FSDP2 cpuoffload --- torchtitan/config_manager.py | 10 ++++++++++ torchtitan/models/llama/model.py | 3 ++- torchtitan/parallelisms/parallelize_llama.py | 10 +++++++++- torchtitan/utils.py | 6 +++++- train.py | 7 ++++++- 5 files changed, 32 insertions(+), 4 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 88e51f0270..f617620120 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -253,6 +253,16 @@ def __init__(self): can be negative. 1 means disabled.""", ) + self.parser.add_argument( + "--training.offload_policy", + type=bool, + default=False, + help=""" + The `offload_policy` argument specifies the offloading policy for FSDP, + whether to offload parameters to CPU when not involved in computation. + If True, then this offloads gradients to CPU as well, meaning that the + optimizer step runs on CPU. """, + ) self.parser.add_argument( "--training.tensor_parallel_degree", type=int, diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 7f102a8012..a8717772eb 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -391,7 +391,8 @@ def init_weights(self): ``init_weights``. We only call it in the constructor of this ``Transformer`` root module to avoid reinitializing tensors. """ - with torch.device(self.freqs_cis.device): + # with torch.device(self.freqs_cis.device): + with torch.device("cuda"): self.freqs_cis = self._precompute_freqs_cis() if self.tok_embeddings is not None: nn.init.normal_(self.tok_embeddings.weight) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index fc26703db0..f301de788d 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -12,7 +12,11 @@ import torch import torch.nn as nn from torch.distributed import DeviceMesh -from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy +from torch.distributed._composable.fsdp import ( + CPUOffloadPolicy, + fully_shard, + MixedPrecisionPolicy, +) from torch.distributed._composable.replicate import replicate from torch.distributed._tensor import Replicate, Shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( @@ -88,6 +92,7 @@ def parallelize_llama( ], tp_enabled=parallel_dims.tp_enabled, pp_enabled=parallel_dims.pp_enabled, + cpu_offload=job_config.training.offload_policy, ) if parallel_dims.dp_replicate_enabled: logger.info("Applied HSDP to the model") @@ -299,12 +304,15 @@ def apply_fsdp( reduce_dtype: torch.dtype, tp_enabled: bool, pp_enabled: bool, + cpu_offload: bool = False, ): """ Apply data parallelism to the model. FSDP2 is used here. """ mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + if cpu_offload: + fsdp_config["offload_policy"] = CPUOffloadPolicy() # TODO: remove this check once PyTorch 2.5 is released. We can safely assume # that users won't use a nightly build which is older than 20240809 by then. diff --git a/torchtitan/utils.py b/torchtitan/utils.py index 7c562b47bf..8f8523e23b 100644 --- a/torchtitan/utils.py +++ b/torchtitan/utils.py @@ -117,8 +117,12 @@ def init_distributed(job_config): os.makedirs(dump_dir, exist_ok=True) _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_") + backend = "nccl" + if job_config.training.offload_policy: + backend = "cuda:nccl,cpu:gloo" torch.distributed.init_process_group( - "nccl", timeout=timedelta(seconds=job_config.comm.init_timeout_seconds) + backend=backend, + timeout=timedelta(seconds=job_config.comm.init_timeout_seconds), ) # to mitigate the memory issue that collectives using diff --git a/train.py b/train.py index 3e8994a34d..8aaf757d12 100644 --- a/train.py +++ b/train.py @@ -170,7 +170,12 @@ def loss_fn(pred, labels): models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config) # move sharded model to CPU/GPU and initialize weights via DTensor - init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda" + init_device = ( + "cpu" + if job_config.checkpoint.create_seed_checkpoint + or job_config.training.offload_policy + else "cuda" + ) model.to_empty(device=init_device) model.init_weights() model.train() From 9a8c90d794b45189e8cd7f2fdb984bad2ed07b5c Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 17 Oct 2024 16:21:57 -0700 Subject: [PATCH 02/16] rename config --- torchtitan/config_manager.py | 8 ++++---- torchtitan/models/llama/model.py | 1 - torchtitan/parallelisms/parallelize_llama.py | 2 +- torchtitan/utils.py | 2 +- train.py | 7 ++++++- 5 files changed, 12 insertions(+), 8 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index f617620120..7cd7ef3baf 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -254,14 +254,14 @@ def __init__(self): 1 means disabled.""", ) self.parser.add_argument( - "--training.offload_policy", + "--training.enable_cpu_offload", type=bool, default=False, help=""" - The `offload_policy` argument specifies the offloading policy for FSDP, - whether to offload parameters to CPU when not involved in computation. + The `enable_cpu_offload` argument specifies whether to have offloading policy + for FSDP, offload parameters to CPU when not involved in computation. If True, then this offloads gradients to CPU as well, meaning that the - optimizer step runs on CPU. """, + optimizer step runs on CPU.""", ) self.parser.add_argument( "--training.tensor_parallel_degree", diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index a8717772eb..4df494a1f0 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -391,7 +391,6 @@ def init_weights(self): ``init_weights``. We only call it in the constructor of this ``Transformer`` root module to avoid reinitializing tensors. """ - # with torch.device(self.freqs_cis.device): with torch.device("cuda"): self.freqs_cis = self._precompute_freqs_cis() if self.tok_embeddings is not None: diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index f301de788d..5bd56f7d86 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -92,7 +92,7 @@ def parallelize_llama( ], tp_enabled=parallel_dims.tp_enabled, pp_enabled=parallel_dims.pp_enabled, - cpu_offload=job_config.training.offload_policy, + cpu_offload=job_config.training.enable_cpu_offload, ) if parallel_dims.dp_replicate_enabled: logger.info("Applied HSDP to the model") diff --git a/torchtitan/utils.py b/torchtitan/utils.py index 8f8523e23b..40e5a981a5 100644 --- a/torchtitan/utils.py +++ b/torchtitan/utils.py @@ -118,7 +118,7 @@ def init_distributed(job_config): _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_") backend = "nccl" - if job_config.training.offload_policy: + if job_config.training.enable_cpu_offload: backend = "cuda:nccl,cpu:gloo" torch.distributed.init_process_group( backend=backend, diff --git a/train.py b/train.py index 8aaf757d12..1487c0f7bd 100644 --- a/train.py +++ b/train.py @@ -173,7 +173,7 @@ def loss_fn(pred, labels): init_device = ( "cpu" if job_config.checkpoint.create_seed_checkpoint - or job_config.training.offload_policy + or job_config.training.enable_cpu_offload else "cuda" ) model.to_empty(device=init_device) @@ -425,7 +425,12 @@ def loss_fn(pred, labels): if __name__ == "__main__": + torch.cuda.memory._record_memory_history(max_entries=100000) config = JobConfig() config.parse_args() main(config) torch.distributed.destroy_process_group() + import pickle + + snapshot = torch.cuda.memory._snapshot() + pickle.dump(snapshot, open("your_name.pickle", "wb")) From 530683dade32da35592ba34fb32d76647f229e07 Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 17 Oct 2024 16:27:39 -0700 Subject: [PATCH 03/16] lint fix --- torchtitan/config_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 7cd7ef3baf..d15f171d41 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -258,9 +258,9 @@ def __init__(self): type=bool, default=False, help=""" - The `enable_cpu_offload` argument specifies whether to have offloading policy - for FSDP, offload parameters to CPU when not involved in computation. - If True, then this offloads gradients to CPU as well, meaning that the + The `enable_cpu_offload` argument specifies whether to have offloading policy + for FSDP, offload parameters to CPU when not involved in computation. + If True, then this offloads gradients to CPU as well, meaning that the optimizer step runs on CPU.""", ) self.parser.add_argument( From 01b0662fb035040b1055c3b5003902124a859db1 Mon Sep 17 00:00:00 2001 From: mori360 Date: Mon, 21 Oct 2024 15:41:51 -0700 Subject: [PATCH 04/16] update cpu offload config --- torchtitan/config_manager.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index d15f171d41..f4c3131a01 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -259,9 +259,8 @@ def __init__(self): default=False, help=""" The `enable_cpu_offload` argument specifies whether to have offloading policy - for FSDP, offload parameters to CPU when not involved in computation. - If True, then this offloads gradients to CPU as well, meaning that the - optimizer step runs on CPU.""", + for FSDP. If True, CPU offload of parameters, gradients, and optimizer states + will be supported.""", ) self.parser.add_argument( "--training.tensor_parallel_degree", From 9569e57628df4b366e2630eebea45a43bc2c479d Mon Sep 17 00:00:00 2001 From: mori360 Date: Tue, 22 Oct 2024 19:54:36 -0700 Subject: [PATCH 05/16] manage freqs_cis as nn.Parameter --- torchtitan/models/llama/model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 4df494a1f0..4de0cf76a0 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -367,7 +367,10 @@ def __init__(self, model_args: ModelArgs): # initialized by the checkpoint, or we need to add a separate initializer for # just the non-persistent buffers that is called after loading checkpoints. self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) - + with torch.device("cpu"): + self.freqs_cis = nn.Parameter( + self._precompute_freqs_cis().to(torch.float32), requires_grad=False + ) self.layers = torch.nn.ModuleDict() for layer_id in range(model_args.n_layers): self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) @@ -391,8 +394,6 @@ def init_weights(self): ``init_weights``. We only call it in the constructor of this ``Transformer`` root module to avoid reinitializing tensors. """ - with torch.device("cuda"): - self.freqs_cis = self._precompute_freqs_cis() if self.tok_embeddings is not None: nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): From 30c204f4fa0f94e758561ac5bc2cbff74ea16f42 Mon Sep 17 00:00:00 2001 From: mori360 Date: Wed, 23 Oct 2024 14:10:34 -0700 Subject: [PATCH 06/16] init freqs_cis at meta device --- torchtitan/models/llama/model.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 4de0cf76a0..bbe2253ad9 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -367,10 +367,9 @@ def __init__(self, model_args: ModelArgs): # initialized by the checkpoint, or we need to add a separate initializer for # just the non-persistent buffers that is called after loading checkpoints. self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) - with torch.device("cpu"): - self.freqs_cis = nn.Parameter( - self._precompute_freqs_cis().to(torch.float32), requires_grad=False - ) + self.freqs_cis = nn.Parameter( + self._precompute_freqs_cis().to(torch.float32), requires_grad=False + ) self.layers = torch.nn.ModuleDict() for layer_id in range(model_args.n_layers): self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) @@ -396,6 +395,8 @@ def init_weights(self): """ if self.tok_embeddings is not None: nn.init.normal_(self.tok_embeddings.weight) + if self.freqs_cis is not None: + nn.init.normal_(self.freqs_cis) for layer in self.layers.values(): if layer is not None: layer.init_weights() From 7d36b87cfbda65d43b1f53d3d7152d7bed66a805 Mon Sep 17 00:00:00 2001 From: mori360 Date: Wed, 23 Oct 2024 14:15:47 -0700 Subject: [PATCH 07/16] remove memory snapshot --- train.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/train.py b/train.py index 1487c0f7bd..1ce5415b7c 100644 --- a/train.py +++ b/train.py @@ -425,12 +425,7 @@ def loss_fn(pred, labels): if __name__ == "__main__": - torch.cuda.memory._record_memory_history(max_entries=100000) config = JobConfig() config.parse_args() main(config) torch.distributed.destroy_process_group() - import pickle - - snapshot = torch.cuda.memory._snapshot() - pickle.dump(snapshot, open("your_name.pickle", "wb")) From c691a6d249d91a6a3173e922f1eb21045ee29037 Mon Sep 17 00:00:00 2001 From: mori360 Date: Wed, 23 Oct 2024 18:03:14 -0700 Subject: [PATCH 08/16] change default device, not set nn.Parameters --- torchtitan/models/llama/model.py | 7 ++----- train.py | 5 +++++ 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 077a6a1e2d..4dc9090ac7 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -367,9 +367,7 @@ def __init__(self, model_args: ModelArgs): # initialized by the checkpoint, or we need to add a separate initializer for # just the non-persistent buffers that is called after loading checkpoints. self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) - self.freqs_cis = nn.Parameter( - self._precompute_freqs_cis().to(torch.float32), requires_grad=False - ) + self.layers = torch.nn.ModuleDict() for layer_id in range(model_args.n_layers): self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) @@ -393,10 +391,9 @@ def init_weights(self): ``init_weights``. We only call it in the constructor of this ``Transformer`` root module to avoid reinitializing tensors. """ + self.freqs_cis = self._precompute_freqs_cis().to() if self.tok_embeddings is not None: nn.init.normal_(self.tok_embeddings.weight) - if self.freqs_cis is not None: - nn.init.normal_(self.freqs_cis) for layer in self.layers.values(): if layer is not None: layer.init_weights() diff --git a/train.py b/train.py index 8919cbbfef..4fdf31b1a3 100644 --- a/train.py +++ b/train.py @@ -111,6 +111,9 @@ def main(job_config: JobConfig): logger.info(f"Building {model_name} {job_config.model.flavor} with {model_config}") with torch.device("meta"): model = model_cls.from_model_args(model_config) + if job_config.training.enable_cpu_offload: + origin_default_device = torch.get_default_device() + torch.set_default_device("cuda") # a no-op hander if float8 is not enabled float8_handler = Float8Handler(job_config, parallel_dims) @@ -171,6 +174,8 @@ def loss_fn(pred, labels): model_parts = [model] + if job_config.training.enable_cpu_offload: + torch.set_default_device(origin_default_device) gpu_mem_stats = gpu_memory_monitor.get_peak_stats() logger.info( f"GPU memory usage for model: " From 3dfca1be8b99fcdad2eeb567e98acb201e828f8e Mon Sep 17 00:00:00 2001 From: mori360 Date: Wed, 23 Oct 2024 18:10:31 -0700 Subject: [PATCH 09/16] typo --- torchtitan/models/llama/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 4dc9090ac7..1fd13f1498 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -391,7 +391,7 @@ def init_weights(self): ``init_weights``. We only call it in the constructor of this ``Transformer`` root module to avoid reinitializing tensors. """ - self.freqs_cis = self._precompute_freqs_cis().to() + self.freqs_cis = self._precompute_freqs_cis() if self.tok_embeddings is not None: nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): From 6fef39a387133067278c7d629ab8bd3d938a2acf Mon Sep 17 00:00:00 2001 From: mori360 Date: Wed, 23 Oct 2024 18:31:14 -0700 Subject: [PATCH 10/16] use a context manager for model.init_weights --- train.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index 4fdf31b1a3..54e2482ab9 100644 --- a/train.py +++ b/train.py @@ -111,9 +111,6 @@ def main(job_config: JobConfig): logger.info(f"Building {model_name} {job_config.model.flavor} with {model_config}") with torch.device("meta"): model = model_cls.from_model_args(model_config) - if job_config.training.enable_cpu_offload: - origin_default_device = torch.get_default_device() - torch.set_default_device("cuda") # a no-op hander if float8 is not enabled float8_handler = Float8Handler(job_config, parallel_dims) @@ -169,13 +166,15 @@ def loss_fn(pred, labels): else "cuda" ) model.to_empty(device=init_device) - model.init_weights() + if job_config.training.enable_cpu_offload: + with torch.device("cuda"): + model.init_weights() + else: + model.init_weights() model.train() model_parts = [model] - if job_config.training.enable_cpu_offload: - torch.set_default_device(origin_default_device) gpu_mem_stats = gpu_memory_monitor.get_peak_stats() logger.info( f"GPU memory usage for model: " From edfe0a68234a563c3699aaae63ff929a4f92e8af Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 24 Oct 2024 12:55:49 -0700 Subject: [PATCH 11/16] correct condition for non_cpu_offloading case --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 54e2482ab9..e17b1d777c 100644 --- a/train.py +++ b/train.py @@ -166,7 +166,7 @@ def loss_fn(pred, labels): else "cuda" ) model.to_empty(device=init_device) - if job_config.training.enable_cpu_offload: + if not job_config.checkpoint.create_seed_checkpoint: with torch.device("cuda"): model.init_weights() else: From f6f9393875d168f99ea807fd1bb9534fd6e608f8 Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 24 Oct 2024 16:26:04 -0700 Subject: [PATCH 12/16] add buffer_device as optional input --- test_runner.py | 2 ++ torchtitan/models/llama/model.py | 12 ++++++++++-- train.py | 7 ++----- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/test_runner.py b/test_runner.py index 61031d742f..0c80cc0061 100755 --- a/test_runner.py +++ b/test_runner.py @@ -57,10 +57,12 @@ def build_test_list(): [ [ "--training.compile", + "--training.enable_cpu_offload True", ], ], "1D compile", "1d_compile", + ngpu=8, ), OverrideDefinitions( [ diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 1fd13f1498..8644629c29 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -379,7 +379,10 @@ def __init__(self, model_args: ModelArgs): self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) self.init_weights() - def init_weights(self): + def init_weights( + self, + buffer_device: Optional[torch.device] = None, + ): """ [Note: On ``init_weights`` vs. ``reset_parameters``] Modules may define ``reset_parameters`` to initialize parameter values. @@ -391,7 +394,12 @@ def init_weights(self): ``init_weights``. We only call it in the constructor of this ``Transformer`` root module to avoid reinitializing tensors. """ - self.freqs_cis = self._precompute_freqs_cis() + if buffer_device is not None: + with torch.device(buffer_device): + self.freqs_cis = self._precompute_freqs_cis() + else: + with torch.device(self.freqs_cis.device): + self.freqs_cis = self._precompute_freqs_cis() if self.tok_embeddings is not None: nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): diff --git a/train.py b/train.py index e17b1d777c..96fe5915e0 100644 --- a/train.py +++ b/train.py @@ -166,11 +166,8 @@ def loss_fn(pred, labels): else "cuda" ) model.to_empty(device=init_device) - if not job_config.checkpoint.create_seed_checkpoint: - with torch.device("cuda"): - model.init_weights() - else: - model.init_weights() + buffer_device = "cuda" if job_config.training.enable_cpu_offload else None + model.init_weights(buffer_device=buffer_device) model.train() model_parts = [model] From d6840dd8de513547ad9e953ed2608a386a0bfc8f Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 24 Oct 2024 16:27:11 -0700 Subject: [PATCH 13/16] add ngpu for 1D test --- test_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_runner.py b/test_runner.py index 0c80cc0061..4fc9517221 100755 --- a/test_runner.py +++ b/test_runner.py @@ -62,7 +62,7 @@ def build_test_list(): ], "1D compile", "1d_compile", - ngpu=8, + ngpu=2, ), OverrideDefinitions( [ From ab1e25819b75cddef9bfa24f42e3b4f455dc1316 Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 24 Oct 2024 19:51:48 -0700 Subject: [PATCH 14/16] modify config help, update condition logic --- test_runner.py | 12 ++++++++++-- torchtitan/config_manager.py | 4 +--- torchtitan/models/llama/model.py | 9 +++------ 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/test_runner.py b/test_runner.py index 4fc9517221..1229dcd166 100755 --- a/test_runner.py +++ b/test_runner.py @@ -57,12 +57,10 @@ def build_test_list(): [ [ "--training.compile", - "--training.enable_cpu_offload True", ], ], "1D compile", "1d_compile", - ngpu=2, ), OverrideDefinitions( [ @@ -353,6 +351,16 @@ def build_test_list(): "fsdp2_mem_tracker", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--training.enable_cpu_offload True", + ], + ], + "Enable CPU Offload", + "enable_cpu_offload", + ngpu=2, + ), ] return integration_tests_flavors diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index c5feec1926..defc010e03 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -258,9 +258,7 @@ def __init__(self): type=bool, default=False, help=""" - The `enable_cpu_offload` argument specifies whether to have offloading policy - for FSDP. If True, CPU offload of parameters, gradients, and optimizer states - will be supported.""", + Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP""", ) self.parser.add_argument( "--training.tensor_parallel_degree", diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 8644629c29..a3bae18a19 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -394,12 +394,9 @@ def init_weights( ``init_weights``. We only call it in the constructor of this ``Transformer`` root module to avoid reinitializing tensors. """ - if buffer_device is not None: - with torch.device(buffer_device): - self.freqs_cis = self._precompute_freqs_cis() - else: - with torch.device(self.freqs_cis.device): - self.freqs_cis = self._precompute_freqs_cis() + buffer_device = buffer_device or self.freqs_cis.device + with torch.device(buffer_device): + self.freqs_cis = self._precompute_freqs_cis() if self.tok_embeddings is not None: nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): From 2ca9882b653a19c4f84376e1a8758cf63a124a76 Mon Sep 17 00:00:00 2001 From: mori360 Date: Mon, 28 Oct 2024 12:01:25 -0700 Subject: [PATCH 15/16] test cpu offload with pp --- test_runner.py | 7 ++++--- train.py | 6 ++++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/test_runner.py b/test_runner.py index 1229dcd166..d98b49b967 100755 --- a/test_runner.py +++ b/test_runner.py @@ -354,12 +354,13 @@ def build_test_list(): OverrideDefinitions( [ [ + "--experimental.pipeline_parallel_degree 2", "--training.enable_cpu_offload True", ], ], - "Enable CPU Offload", - "enable_cpu_offload", - ngpu=2, + "Enable CPU Offload with PP", + "enable_cpu_offload+PP", + ngpu=4, ), ] return integration_tests_flavors diff --git a/train.py b/train.py index 96fe5915e0..f12632de44 100644 --- a/train.py +++ b/train.py @@ -151,8 +151,10 @@ def loss_fn(pred, labels): for m in model_parts: # apply SPMD-style PT-D techniques models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config) - m.to_empty(device="cuda") - m.init_weights() + init_device = "cpu" if job_config.training.enable_cpu_offload else "cuda" + m.to_empty(device=init_device) + buffer_device = "cuda" if job_config.training.enable_cpu_offload else None + m.init_weights(buffer_device=buffer_device) m.train() else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel From 7af331cdf9ecc168ddd857b7a3963fd7e8c783ab Mon Sep 17 00:00:00 2001 From: mori360 Date: Mon, 28 Oct 2024 13:21:31 -0700 Subject: [PATCH 16/16] move init_device and buffer_device outside pp condition --- train.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/train.py b/train.py index f12632de44..bc04dad0ad 100644 --- a/train.py +++ b/train.py @@ -138,6 +138,17 @@ def loss_fn(pred, labels): if job_config.training.compile: loss_fn = torch.compile(loss_fn) + # move sharded model to CPU/GPU and initialize weights via DTensor + if job_config.checkpoint.create_seed_checkpoint: + init_device = "cpu" + buffer_device = None + elif job_config.training.enable_cpu_offload: + init_device = "cpu" + buffer_device = "cuda" + else: + init_device = "cuda" + buffer_device = None + # apply parallelisms and initialization if parallel_dims.pp_enabled: # apply PT-D Pipeline Parallel @@ -151,24 +162,13 @@ def loss_fn(pred, labels): for m in model_parts: # apply SPMD-style PT-D techniques models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config) - init_device = "cpu" if job_config.training.enable_cpu_offload else "cuda" m.to_empty(device=init_device) - buffer_device = "cuda" if job_config.training.enable_cpu_offload else None m.init_weights(buffer_device=buffer_device) m.train() else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config) - - # move sharded model to CPU/GPU and initialize weights via DTensor - init_device = ( - "cpu" - if job_config.checkpoint.create_seed_checkpoint - or job_config.training.enable_cpu_offload - else "cuda" - ) model.to_empty(device=init_device) - buffer_device = "cuda" if job_config.training.enable_cpu_offload else None model.init_weights(buffer_device=buffer_device) model.train()