 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
+    DenseTableBatchedEmbeddingBagsCodegen,
+)
 from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
 )
 from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
+from torchrec.distributed.sharding.dynamic_sharding_utils import (
+    shards_all_to_all,
+    update_state_dict_post_resharding,
+)
 from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
 from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
 from torchrec.distributed.sharding.tw_sharding import TwPooledEmbeddingSharding
@@ -635,14 +642,17 @@ def __init__(
         self._env = env
         # output parameters as DTensor in state dict
         self._output_dtensor: bool = env.output_dtensor
-
-        sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
-            module,
-            table_name_to_parameter_sharding,
-            "embedding_bags.",
-            fused_params,
+        self.sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = (
+            create_sharding_infos_by_sharding(
+                module,
+                table_name_to_parameter_sharding,
+                "embedding_bags.",
+                fused_params,
+            )
+        )
+        self._sharding_types: List[str] = list(
+            self.sharding_type_to_sharding_infos.keys()
         )
-        self._sharding_types: List[str] = list(sharding_type_to_sharding_infos.keys())
         self._embedding_shardings: List[
             EmbeddingSharding[
                 EmbeddingShardingContext,
@@ -658,7 +668,7 @@ def __init__(
                 permute_embeddings=True,
                 qcomm_codecs_registry=self.qcomm_codecs_registry,
             )
-            for embedding_configs in sharding_type_to_sharding_infos.values()
+            for embedding_configs in self.sharding_type_to_sharding_infos.values()
         ]

         self._is_weighted: bool = module.is_weighted()
@@ -833,7 +843,7 @@ def _pre_load_state_dict_hook(
                 lookup = lookup.module
             lookup.purge()

-    def _initialize_torch_state(self) -> None:  # noqa
+    def _initialize_torch_state(self, skip_registering: bool = False) -> None:  # noqa
         """
         This provides consistency between this class and the EmbeddingBagCollection's
         nn.Module API calls (state_dict, named_modules, etc)
@@ -1063,11 +1073,12 @@ def post_state_dict_hook(
                 destination_key = f"{prefix}embedding_bags.{table_name}.weight"
                 destination[destination_key] = sharded_kvtensor

-        self.register_state_dict_pre_hook(self._pre_state_dict_hook)
-        self._register_state_dict_hook(post_state_dict_hook)
-        self._register_load_state_dict_pre_hook(
-            self._pre_load_state_dict_hook, with_module=True
-        )
+        if not skip_registering:
+            self.register_state_dict_pre_hook(self._pre_state_dict_hook)
+            self._register_state_dict_hook(post_state_dict_hook)
+            self._register_load_state_dict_pre_hook(
+                self._pre_load_state_dict_hook, with_module=True
+            )
         self.reset_parameters()

     def reset_parameters(self) -> None:
@@ -1164,6 +1175,40 @@ def _create_output_dist(self) -> None:
             self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims())
             embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
         self._dim_per_key = torch.tensor(self._embedding_dims, device=self._device)
+
+        embedding_shard_offsets: List[int] = [
+            meta.shard_offsets[1] if meta is not None else 0
+            for meta in embedding_shard_metadata
+        ]
+        embedding_name_order: Dict[str, int] = {}
+        for i, name in enumerate(self._uncombined_embedding_names):
+            embedding_name_order.setdefault(name, i)
+
+        def sort_key(input: Tuple[int, str]) -> Tuple[int, int]:
+            index, name = input
+            return (embedding_name_order[name], embedding_shard_offsets[index])
+
+        permute_indices = [
+            i
+            for i, _ in sorted(
+                enumerate(self._uncombined_embedding_names), key=sort_key
+            )
+        ]
+        self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings(
+            self._uncombined_embedding_dims, permute_indices, self._device
+        )
+
+    def _update_output_dist(self) -> None:
+        embedding_shard_metadata: List[Optional[ShardMetadata]] = []
+        # TODO: Optimize to only go through embedding shardings with new ranks
+        self._output_dists: List[nn.Module] = []
+        self._embedding_names: List[str] = []
+        for sharding in self._embedding_shardings:
+            # TODO: if sharding type of table completely changes, need to regenerate everything
+            self._embedding_names.extend(sharding.embedding_names())
+            self._output_dists.append(sharding.create_output_dist(device=self._device))
+            embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
+
         embedding_shard_offsets: List[int] = [
             meta.shard_offsets[1] if meta is not None else 0
             for meta in embedding_shard_metadata
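The permute logic added in this hunk orders per-shard outputs first by table (first appearance of the table name) and then by column offset, so that PermutePooledEmbeddings can reassemble contiguous per-table embeddings. A minimal standalone sketch of that ordering follows; the shard names and offsets below are illustrative assumptions, not values from the diff.

# Sketch of the sort-key ordering used above, with toy inputs.
uncombined_names = ["t1", "t0", "t1", "t0"]  # shard index -> table name
shard_offsets = [64, 0, 0, 128]              # shard index -> column offset

name_order = {}
for i, name in enumerate(uncombined_names):
    name_order.setdefault(name, i)           # first appearance of each table wins

permute = [
    i
    for i, _ in sorted(
        enumerate(uncombined_names),
        key=lambda x: (name_order[x[1]], shard_offsets[x[0]]),
    )
]
print(permute)  # [2, 0, 1, 3]: t1's shards ordered by offset, then t0's shards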
@@ -1399,6 +1444,105 @@ def compute_and_output_dist(

         return awaitable

+    def update_shards(
+        self,
+        changed_sharding_params: Dict[str, ParameterSharding],  # NOTE: only delta
+        env: ShardingEnv,
+        device: Optional[torch.device],
+    ) -> None:
+        """
+        Update shards for this module based on the changed_sharding_params. This will:
+        1. Move current lookup tensors to CPU
+        2. Purge lookups
+        3. Call the shards_all_to_all collective to redistribute tensors
+        4. Update state_dict and other attributes to reflect new placements and shards
+        5. Create new lookups and load in the updated state_dict
+
+        Args:
+            changed_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping
+                table names to their new parameter sharding configs. This should only
+                contain shards/table names that need to be moved.
+            env (ShardingEnv): The sharding environment for the module.
+            device (Optional[torch.device]): The device to place the updated module on.
+        """
+
+        if env.output_dtensor:
+            raise RuntimeError("We do not yet support DTensor for resharding")
+            return
+
+        current_state = self.state_dict()
+        # TODO: Save Optimizers
+
+        saved_weights = {}
+        # TODO: Save lookup tensors to CPU to eventually avoid recreating them from scratch
+        for i, lookup in enumerate(self._lookups):
+            for attribute, tbe_module in lookup.named_modules():
+                if type(tbe_module) is DenseTableBatchedEmbeddingBagsCodegen:
+                    saved_weights[str(i) + "." + attribute] = tbe_module.weights.cpu()
+                    # Note: lookup.purge should delete tbe_module and weights
+                    # del tbe_module.weights
+                    # del tbe_module
+            # pyre-ignore
+            lookup.purge()
+
+        # Deleting all lookups
+        self._lookups.clear()
+
+        local_output_by_src_rank, local_output_tensor = shards_all_to_all(
+            module=self,
+            device=device,  # pyre-ignore
+            changed_sharding_params=changed_sharding_params,
+            env=env,
+        )
+
+        current_state = update_state_dict_post_resharding(
+            update_state_dict=current_state,
+            local_output_by_src_rank=local_output_by_src_rank,
+            local_output_tensor=local_output_tensor,
+            changed_sharding_params=changed_sharding_params,
+            curr_rank=dist.get_rank(),
+        )
+
+        for name, param in changed_sharding_params.items():
+            self.module_sharding_plan[name] = param
+            # TODO: Support detecting old sharding type when sharding type is changing
+            for sharding_info in self.sharding_type_to_sharding_infos[
+                param.sharding_type
+            ]:
+                if sharding_info.embedding_config.name == name:
+                    sharding_info.param_sharding = param
+
+        self._sharding_types: List[str] = list(
+            self.sharding_type_to_sharding_infos.keys()
+        )
+        # TODO: Optimize to update only the changed embedding shardings
+        self._embedding_shardings: List[
+            EmbeddingSharding[
+                EmbeddingShardingContext,
+                KeyedJaggedTensor,
+                torch.Tensor,
+                torch.Tensor,
+            ]
+        ] = [
+            create_embedding_bag_sharding(
+                embedding_configs,
+                env,
+                device,
+                permute_embeddings=True,
+                qcomm_codecs_registry=self.qcomm_codecs_registry,
+            )
+            for embedding_configs in self.sharding_type_to_sharding_infos.values()
+        ]
+
+        self._create_lookups()
+        self._update_output_dist()
+
+        if env.process_group and dist.get_backend(env.process_group) != "fake":
+            self._initialize_torch_state(skip_registering=True)
+
+        self.load_state_dict(current_state)
+        return
+
     @property
     def fused_optimizer(self) -> KeyedOptimizer:
         return self._optim
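As a rough usage sketch of the new update_shards API: the names sharded_ebc (an already-sharded EmbeddingBagCollection), new_plan (the recomputed ParameterSharding for moved tables), and env are assumptions for illustration; only the delta is passed, per the docstring above.

# Hypothetical resharding call; `sharded_ebc`, `new_plan`, and `env` are assumed
# to exist already. update_shards mutates the module in place: it saves TBE
# weights to CPU, purges the lookups, redistributes shards via shards_all_to_all,
# rebuilds shardings and lookups, then reloads the updated state_dict.
changed: Dict[str, ParameterSharding] = {
    "table_0": new_plan["table_0"],  # only the tables whose placement changed
}
sharded_ebc.update_shards(
    changed_sharding_params=changed,
    env=env,
    device=torch.device("cuda"),
)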
@@ -1438,6 +1582,33 @@ def shardable_parameters(
             for name, param in module.embedding_bags.named_parameters()
         }

+    def reshard(
+        self,
+        sharded_module: ShardedEmbeddingBagCollection,
+        changed_shard_to_params: Dict[str, ParameterSharding],
+        env: ShardingEnv,
+        device: Optional[torch.device] = None,
+    ) -> ShardedEmbeddingBagCollection:
+        """
+        Updates the sharded module in place based on changed_shard_to_params,
+        which contains the new ParameterSharding configs with different shard placements.
+
+        Args:
+            sharded_module (ShardedEmbeddingBagCollection): The module to update.
+            changed_shard_to_params (Dict[str, ParameterSharding]): A dictionary mapping
+                table names to their new parameter sharding configs. This should only
+                contain shards/table names that need to be moved.
+            env (ShardingEnv): The sharding environment.
+            device (Optional[torch.device]): The device to place the updated module on.
+
+        Returns:
+            ShardedEmbeddingBagCollection: The updated sharded module
+        """
+
+        if len(changed_shard_to_params) > 0:
+            sharded_module.update_shards(changed_shard_to_params, env, device)
+        return sharded_module
+
     @property
     def module_type(self) -> Type[EmbeddingBagCollection]:
         return EmbeddingBagCollection
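The sharder-level reshard wrapper above appears to be the intended entry point rather than calling the module directly; a hedged sketch follows, where sharder, sharded_ebc, changed, and env are assumed to exist as in the previous example.

# Hypothetical call through the sharder that owns shardable_parameters/module_type.
# An empty delta is a no-op; otherwise reshard delegates to update_shards and
# returns the same (mutated) module.
sharded_ebc = sharder.reshard(
    sharded_module=sharded_ebc,
    changed_shard_to_params=changed,
    env=env,
    device=torch.device("cuda"),
)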