diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py
index 4ae6cb1aa2..53e13f64fa 100644
--- a/distributed/tensor_parallelism/fsdp_tp_example.py
+++ b/distributed/tensor_parallelism/fsdp_tp_example.py
@@ -49,7 +49,7 @@
 from llama2_model import Transformer, ModelArgs
 
 from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import fully_shard
 from torch.distributed._tensor import Shard, Replicate
 from torch.distributed.tensor.parallel import (
     parallelize_module,
@@ -151,7 +151,7 @@
 )
 
 # Init FSDP using the dp device mesh
-sharded_model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)
+sharded_model = fully_shard(model, mesh=dp_mesh)
 
 rank_log(_rank, logger, f"Model after parallelization {sharded_model=}\n")
 
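
For context, a minimal sketch of the migrated FSDP2 call in isolation, assuming a launch via torchrun on 4 GPUs; the toy nn.Sequential model and the "dp"/"tp" mesh dimension names below are illustrative and not taken from the example script:

# Minimal FSDP2 (fully_shard) sketch, assumed to run under `torchrun --nproc_per_node=4`.
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

# Bind each rank to its local GPU (LOCAL_RANK is set by torchrun).
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# 2D mesh: the outer "dp" dim is used for FSDP, the inner "tp" dim would be
# used for tensor parallelism (omitted here to keep the sketch short).
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
dp_mesh = mesh_2d["dp"]

# Illustrative toy model standing in for the Transformer in the example.
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).cuda()

# FSDP2: fully_shard registers sharding on the module in place and returns it,
# so there is no wrapper class and no use_orig_params flag anymore.
sharded_model = fully_shard(model, mesh=dp_mesh)

out = sharded_model(torch.randn(8, 16, device="cuda"))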