import torch
from torch import distributed as dist, nn
from torch.autograd.profiler import record_function
+from torch.distributed._tensor import DTensor
+from torch.distributed._tensor._shards_wrapper import LocalShardsWrapper
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.embedding_sharding import (
    EmbeddingSharding,
@@ -589,18 +591,20 @@ def _pre_load_state_dict_hook(
    ) -> None:
        """
        Modify the destination state_dict for model parallel
-        to transform from ShardedTensors into tensors
+        to transform from ShardedTensors/DTensors into tensors
        """
-        for (
-            table_name,
-            model_shards,
-        ) in self._model_parallel_name_to_local_shards.items():
+        for table_name in self._model_parallel_name_to_local_shards.keys():
            key = f"{prefix}embeddings.{table_name}.weight"
-
+            # gather model shards from both DTensor and ShardedTensor maps
+            model_shards_sharded_tensor = self._model_parallel_name_to_local_shards[
+                table_name
+            ]
+            model_shards_dtensor = self._model_parallel_name_to_shards_wrapper[
+                table_name
+            ]
            # If state_dict[key] is already a ShardedTensor, use its local shards
            if isinstance(state_dict[key], ShardedTensor):
                local_shards = state_dict[key].local_shards()
-                # If no local shards, create an empty tensor
                if len(local_shards) == 0:
                    state_dict[key] = torch.empty(0)
                else:
@@ -612,27 +616,57 @@ def _pre_load_state_dict_hook(
                        ).view(-1, dim)
                    else:
                        state_dict[key] = local_shards[0].tensor.view(-1, dim)
-            else:
+            elif isinstance(state_dict[key], DTensor):
+                shards_wrapper = state_dict[key].to_local()
+                local_shards = shards_wrapper.local_shards()
+                dim = shards_wrapper.local_sizes()[0][1]
+                if len(local_shards) == 0:
+                    state_dict[key] = torch.empty(0)
+                elif len(local_shards) > 1:
+                    # TODO - add multiple shards on rank support
+                    raise RuntimeError(
+                        f"Multiple shards on rank is not supported for DTensor yet, got {len(local_shards)}"
+                    )
+                else:
+                    state_dict[key] = local_shards[0].view(-1, dim)
+            elif isinstance(state_dict[key], torch.Tensor):
                local_shards = []
-                for shard in model_shards:
-                    # Extract shard size and offsets for splicing
-                    shard_sizes = shard.metadata.shard_sizes
-                    shard_offsets = shard.metadata.shard_offsets
-
-                    # Prepare tensor by splicing and placing on appropriate device
-                    spliced_tensor = state_dict[key][
-                        shard_offsets[0] : shard_offsets[0] + shard_sizes[0],
-                        shard_offsets[1] : shard_offsets[1] + shard_sizes[1],
-                    ].to(shard.tensor.get_device())
-
-                    # Append spliced tensor into local shards
-                    local_shards.append(spliced_tensor)
-
+                if model_shards_sharded_tensor:
+                    # splice according to sharded tensor metadata
+                    for shard in model_shards_sharded_tensor:
+                        # Extract shard size and offsets for splicing
+                        shard_size = shard.metadata.shard_sizes
+                        shard_offset = shard.metadata.shard_offsets
+
+                        # Prepare tensor by splicing and placing on appropriate device
+                        spliced_tensor = state_dict[key][
+                            shard_offset[0] : shard_offset[0] + shard_size[0],
+                            shard_offset[1] : shard_offset[1] + shard_size[1],
+                        ]
+
+                        # Append spliced tensor into local shards
+                        local_shards.append(spliced_tensor)
+                elif model_shards_dtensor:
+                    # splice according to dtensor metadata
+                    for tensor, shard_offset in zip(
+                        model_shards_dtensor["local_tensors"],
+                        model_shards_dtensor["local_offsets"],
+                    ):
+                        shard_size = tensor.size()
+                        spliced_tensor = state_dict[key][
+                            shard_offset[0] : shard_offset[0] + shard_size[0],
+                            shard_offset[1] : shard_offset[1] + shard_size[1],
+                        ]
+                        local_shards.append(spliced_tensor)
                state_dict[key] = (
                    torch.empty(0)
                    if not local_shards
                    else torch.cat(local_shards, dim=0)
                )
+            else:
+                raise RuntimeError(
+                    f"Unexpected state_dict key type {type(state_dict[key])} found for {key}"
+                )

        for lookup in self._lookups:
            while isinstance(lookup, DistributedDataParallel):
@@ -649,7 +683,9 @@ def _initialize_torch_state(self) -> None: # noqa
        for table_name in self._table_names:
            self.embeddings[table_name] = nn.Module()
        self._model_parallel_name_to_local_shards = OrderedDict()
+        self._model_parallel_name_to_shards_wrapper = OrderedDict()
        self._model_parallel_name_to_sharded_tensor = OrderedDict()
+        self._model_parallel_name_to_dtensor = OrderedDict()
        model_parallel_name_to_compute_kernel: Dict[str, str] = {}
        for (
            table_name,
@@ -658,6 +694,9 @@ def _initialize_torch_state(self) -> None: # noqa
            if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
                continue
            self._model_parallel_name_to_local_shards[table_name] = []
+            self._model_parallel_name_to_shards_wrapper[table_name] = OrderedDict(
+                [("local_tensors", []), ("local_offsets", [])]
+            )
            model_parallel_name_to_compute_kernel[table_name] = (
                parameter_sharding.compute_kernel
            )
@@ -679,18 +718,29 @@ def _initialize_torch_state(self) -> None: # noqa
            # save local_shards for transforming MP params to shardedTensor
            for key, v in lookup.state_dict().items():
                table_name = key[: -len(".weight")]
-                self._model_parallel_name_to_local_shards[table_name].extend(
-                    v.local_shards()
-                )
+                if isinstance(v, DTensor):
+                    shards_wrapper = self._model_parallel_name_to_shards_wrapper[
+                        table_name
+                    ]
+                    local_shards_wrapper = v._local_tensor
+                    shards_wrapper["local_tensors"].extend(local_shards_wrapper.local_shards())  # pyre-ignore[16]
+                    shards_wrapper["local_offsets"].extend(local_shards_wrapper.local_offsets())  # pyre-ignore[16]
+                    shards_wrapper["global_size"] = v.size()
+                    shards_wrapper["global_stride"] = v.stride()
+                    shards_wrapper["placements"] = v.placements
+                elif isinstance(v, ShardedTensor):
+                    self._model_parallel_name_to_local_shards[table_name].extend(
+                        v.local_shards()
+                    )
            for (
                table_name,
                tbe_slice,
            ) in lookup.named_parameters_by_table():
                self.embeddings[table_name].register_parameter("weight", tbe_slice)
-        for (
-            table_name,
-            local_shards,
-        ) in self._model_parallel_name_to_local_shards.items():
+        for table_name in self._model_parallel_name_to_local_shards.keys():
+            local_shards = self._model_parallel_name_to_local_shards[table_name]
+            shards_wrapper_map = self._model_parallel_name_to_shards_wrapper[table_name]
+
            # for shards that don't exist on this rank, register with empty tensor
            if not hasattr(self.embeddings[table_name], "weight"):
                self.embeddings[table_name].register_parameter(
@@ -703,18 +753,34 @@ def _initialize_torch_state(self) -> None: # noqa
                self.embeddings[table_name].weight._in_backward_optimizers = [
                    EmptyFusedOptimizer()
                ]
+
            if model_parallel_name_to_compute_kernel[table_name] in {
                EmbeddingComputeKernel.KEY_VALUE.value
            }:
                continue
-            # created ShardedTensors once in init, use in post_state_dict_hook
-            self._model_parallel_name_to_sharded_tensor[table_name] = (
-                ShardedTensor._init_from_local_shards(
-                    local_shards,
-                    self._name_to_table_size[table_name],
-                    process_group=self._env.process_group,
+
+            if shards_wrapper_map["local_tensors"]:
+                self._model_parallel_name_to_dtensor[table_name] = DTensor.from_local(
+                    local_tensor=LocalShardsWrapper(
+                        local_shards=shards_wrapper_map["local_tensors"],
+                        local_offsets=shards_wrapper_map["local_offsets"],
+                    ),
+                    device_mesh=self._env.device_mesh,
+                    placements=shards_wrapper_map["placements"],
+                    shape=shards_wrapper_map["global_size"],
+                    stride=shards_wrapper_map["global_stride"],
+                    run_check=False,
+                )
+            else:
+                # if DTensors for table do not exist, create ShardedTensor
+                # created ShardedTensors once in init, use in post_state_dict_hook
+                self._model_parallel_name_to_sharded_tensor[table_name] = (
+                    ShardedTensor._init_from_local_shards(
+                        local_shards,
+                        self._name_to_table_size[table_name],
+                        process_group=self._env.process_group,
+                    )
                )
-            )

        def post_state_dict_hook(
            module: ShardedEmbeddingCollection,
@@ -729,6 +795,12 @@ def post_state_dict_hook(
            ) in module._model_parallel_name_to_sharded_tensor.items():
                destination_key = f"{prefix}embeddings.{table_name}.weight"
                destination[destination_key] = sharded_t
+            for (
+                table_name,
+                d_tensor,
+            ) in module._model_parallel_name_to_dtensor.items():
+                destination_key = f"{prefix}embeddings.{table_name}.weight"
+                destination[destination_key] = d_tensor

        self.register_state_dict_pre_hook(self._pre_state_dict_hook)
        self._register_state_dict_hook(post_state_dict_hook)
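
For readers unfamiliar with the new path, here is a minimal sketch, not part of this commit, of the DTensor/LocalShardsWrapper round trip the new hooks rely on: wrap a rank-local shard, lift it into a DTensor the way _initialize_torch_state does, and read it back the way the DTensor branch of _pre_load_state_dict_hook does. It assumes a single-process gloo group, a 1-D CPU device mesh, and a hypothetical 4 x 8 table purely for illustration; note that torch.distributed._tensor._shards_wrapper is a private module path taken from the commit's imports and may move between torch releases.

# Minimal illustrative sketch (assumptions noted above), not from the commit itself.
import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard
from torch.distributed._tensor._shards_wrapper import LocalShardsWrapper
from torch.distributed.device_mesh import init_device_mesh

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

local_shard = torch.randn(4, 8)  # this rank's rows of a hypothetical 4 x 8 table
wrapper = LocalShardsWrapper(
    local_shards=[local_shard],
    local_offsets=[torch.Size([0, 0])],
)
d_tensor = DTensor.from_local(
    local_tensor=wrapper,
    device_mesh=mesh,
    placements=[Shard(0)],     # row-wise sharding along dim 0
    shape=torch.Size([4, 8]),  # global size (equals the local size at world_size=1)
    stride=(8, 1),
    run_check=False,
)

# Mirrors the elif-DTensor branch of _pre_load_state_dict_hook:
shards_wrapper = d_tensor.to_local()
dim = shards_wrapper.local_sizes()[0][1]
flat = shards_wrapper.local_shards()[0].view(-1, dim)

dist.destroy_process_group()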