
Commit 0e30a20

aporialia authored and facebook-github-bot committed
Dynamic Sharding API + Test for EBC, TW, ShardedTensor (#2852)
Summary: Add the initial dynamic sharding API and test. This version supports EBC, TW, and ShardedTensor. Other variants beyond those configurations (e.g. CW, RW, DTensor, etc.) will be added in the next few diffs.

What's added here:
1. A `reshard` API which implements the `update_shards` API for `ShardedEmbeddingBagCollection`.
2. Util functions for dynamic sharding, used by the `update_shards` API:
   1. `extend_shard_name`: extends `table_i` to `embedding_bags.table_i.weight`.
   2. `shards_all_to_all`: contains the all-to-all collective call that redistributes shards in a distributed environment, based on the `changed_sharding_params`.
   3. `update_state_dict_post_resharding`: updates a given `state_dict` with new shard `placements` and `local_shards`.
3. A multi-process unit test `test_dynamic_sharding_ebc_tw` that calls the `reshard` API on TW-sharded EBCs, sampling from various `world_sizes`, `num_tables`, and `data_types`.
   1. This unit test also uses a few utils to generate random inputs and rank placements. A future TODO is to merge this input generation with the generate call in D71703434.

Future work items (features not yet supported in this diff):
* CW, RW, and many other sharding types
* Optimizer saving
* DTensor implementation

Differential Revision: D69095169
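Below is a minimal usage sketch of the new `reshard` entry point. It is not part of this commit: the wrapper name `reshard_delta` and its inputs (`sharder`, `sharded_ebc`, `delta_plan`) are illustrative assumptions supplied by the caller; only the `reshard` call itself comes from the diff.

```python
# Hypothetical wrapper illustrating the intended call flow; `reshard_delta`
# and its inputs are assumptions for this sketch, not part of the commit.
from typing import Dict, Optional

import torch
from torchrec.distributed.embeddingbag import (
    EmbeddingBagCollectionSharder,
    ShardedEmbeddingBagCollection,
)
from torchrec.distributed.types import ParameterSharding, ShardingEnv


def reshard_delta(
    sharder: EmbeddingBagCollectionSharder,
    sharded_ebc: ShardedEmbeddingBagCollection,
    delta_plan: Dict[str, ParameterSharding],  # only tables whose placement changes
    env: ShardingEnv,
    device: Optional[torch.device] = None,
) -> ShardedEmbeddingBagCollection:
    # Per the docstrings added in this diff, reshard() is a no-op for an empty
    # delta; otherwise it calls ShardedEmbeddingBagCollection.update_shards()
    # to redistribute the moved shards in place.
    return sharder.reshard(sharded_ebc, delta_plan, env, device)
```

Note that `delta_plan` should contain only the tables whose shards actually move; tables with unchanged placements are left untouched.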
1 parent f0ae23d commit 0e30a20

File tree

3 files changed: +783 −14 lines


torchrec/distributed/embeddingbag.py

Lines changed: 185 additions & 14 deletions
@@ -27,6 +27,9 @@
 
 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
+    DenseTableBatchedEmbeddingBagsCodegen,
+)
 from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
@@ -50,6 +53,10 @@
 )
 from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
+from torchrec.distributed.sharding.dynamic_sharding_utils import (
+    shards_all_to_all,
+    update_state_dict_post_resharding,
+)
 from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
 from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
 from torchrec.distributed.sharding.tw_sharding import TwPooledEmbeddingSharding
@@ -635,14 +642,17 @@ def __init__(
         self._env = env
         # output parameters as DTensor in state dict
         self._output_dtensor: bool = env.output_dtensor
-
-        sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
-            module,
-            table_name_to_parameter_sharding,
-            "embedding_bags.",
-            fused_params,
+        self.sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = (
+            create_sharding_infos_by_sharding(
+                module,
+                table_name_to_parameter_sharding,
+                "embedding_bags.",
+                fused_params,
+            )
+        )
+        self._sharding_types: List[str] = list(
+            self.sharding_type_to_sharding_infos.keys()
         )
-        self._sharding_types: List[str] = list(sharding_type_to_sharding_infos.keys())
         self._embedding_shardings: List[
             EmbeddingSharding[
                 EmbeddingShardingContext,
@@ -658,7 +668,7 @@ def __init__(
                 permute_embeddings=True,
                 qcomm_codecs_registry=self.qcomm_codecs_registry,
             )
-            for embedding_configs in sharding_type_to_sharding_infos.values()
+            for embedding_configs in self.sharding_type_to_sharding_infos.values()
         ]
 
         self._is_weighted: bool = module.is_weighted()
@@ -833,7 +843,7 @@ def _pre_load_state_dict_hook(
                 lookup = lookup.module
             lookup.purge()
 
-    def _initialize_torch_state(self) -> None:  # noqa
+    def _initialize_torch_state(self, skip_registering: bool = False) -> None:  # noqa
         """
         This provides consistency between this class and the EmbeddingBagCollection's
         nn.Module API calls (state_dict, named_modules, etc)
@@ -1063,11 +1073,12 @@ def post_state_dict_hook(
                 destination_key = f"{prefix}embedding_bags.{table_name}.weight"
                 destination[destination_key] = sharded_kvtensor
 
-        self.register_state_dict_pre_hook(self._pre_state_dict_hook)
-        self._register_state_dict_hook(post_state_dict_hook)
-        self._register_load_state_dict_pre_hook(
-            self._pre_load_state_dict_hook, with_module=True
-        )
+        if not skip_registering:
+            self.register_state_dict_pre_hook(self._pre_state_dict_hook)
+            self._register_state_dict_hook(post_state_dict_hook)
+            self._register_load_state_dict_pre_hook(
+                self._pre_load_state_dict_hook, with_module=True
+            )
         self.reset_parameters()
 
     def reset_parameters(self) -> None:
@@ -1164,6 +1175,40 @@ def _create_output_dist(self) -> None:
             self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims())
             embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
         self._dim_per_key = torch.tensor(self._embedding_dims, device=self._device)
+
+        embedding_shard_offsets: List[int] = [
+            meta.shard_offsets[1] if meta is not None else 0
+            for meta in embedding_shard_metadata
+        ]
+        embedding_name_order: Dict[str, int] = {}
+        for i, name in enumerate(self._uncombined_embedding_names):
+            embedding_name_order.setdefault(name, i)
+
+        def sort_key(input: Tuple[int, str]) -> Tuple[int, int]:
+            index, name = input
+            return (embedding_name_order[name], embedding_shard_offsets[index])
+
+        permute_indices = [
+            i
+            for i, _ in sorted(
+                enumerate(self._uncombined_embedding_names), key=sort_key
+            )
+        ]
+        self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings(
+            self._uncombined_embedding_dims, permute_indices, self._device
+        )
+
+    def _update_output_dist(self) -> None:
+        embedding_shard_metadata: List[Optional[ShardMetadata]] = []
+        # TODO: Optimize to only go through embedding shardings with new ranks
+        self._output_dists: List[nn.Module] = []
+        self._embedding_names: List[str] = []
+        for sharding in self._embedding_shardings:
+            # TODO: if sharding type of table completely changes, need to regenerate everything
+            self._embedding_names.extend(sharding.embedding_names())
+            self._output_dists.append(sharding.create_output_dist(device=self._device))
+            embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
+
         embedding_shard_offsets: List[int] = [
             meta.shard_offsets[1] if meta is not None else 0
             for meta in embedding_shard_metadata
@@ -1399,6 +1444,105 @@ def compute_and_output_dist(
 
         return awaitable
 
+    def update_shards(
+        self,
+        changed_sharding_params: Dict[str, ParameterSharding],  # NOTE: only delta
+        env: ShardingEnv,
+        device: Optional[torch.device],
+    ) -> None:
+        """
+        Update shards for this module based on the changed_sharding_params. This will:
+        1. Move current lookup tensors to CPU
+        2. Purge lookups
+        3. Call shards_all_2_all containing collective to redistribute tensors
+        4. Update state_dict and other attributes to reflect new placements and shards
+        5. Create new lookups, and load in updated state_dict
+
+        Args:
+            changed_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping
+                table names to their new parameter sharding configs. This should only
+                contain shards/table names that need to be moved.
+            env (ShardingEnv): The sharding environment for the module.
+            device (Optional[torch.device]): The device to place the updated module on.
+        """
+
+        if env.output_dtensor:
+            raise RuntimeError("We do not yet support DTensor for resharding yet")
+            return
+
+        current_state = self.state_dict()
+        # TODO: Save Optimizers
+
+        saved_weights = {}
+        # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again
+        for i, lookup in enumerate(self._lookups):
+            for attribute, tbe_module in lookup.named_modules():
+                if type(tbe_module) is DenseTableBatchedEmbeddingBagsCodegen:
+                    saved_weights[str(i) + "." + attribute] = tbe_module.weights.cpu()
+                    # Note: lookup.purge should delete tbe_module and weights
+                    # del tbe_module.weights
+                    # del tbe_module
+            # pyre-ignore
+            lookup.purge()
+
+        # Deleting all lookups
+        self._lookups.clear()
+
+        local_output_by_src_rank, local_output_tensor = shards_all_to_all(
+            module=self,
+            device=device,  # pyre-ignore
+            changed_sharding_params=changed_sharding_params,
+            env=env,
+        )
+
+        current_state = update_state_dict_post_resharding(
+            update_state_dict=current_state,
+            local_output_by_src_rank=local_output_by_src_rank,
+            local_output_tensor=local_output_tensor,
+            changed_sharding_params=changed_sharding_params,
+            curr_rank=dist.get_rank(),
+        )
+
+        for name, param in changed_sharding_params.items():
+            self.module_sharding_plan[name] = param
+            # TODO: Support detecting old sharding type when sharding type is changing
+            for sharding_info in self.sharding_type_to_sharding_infos[
+                param.sharding_type
+            ]:
+                if sharding_info.embedding_config.name == name:
+                    sharding_info.param_sharding = param
+
+        self._sharding_types: List[str] = list(
+            self.sharding_type_to_sharding_infos.keys()
+        )
+        # TODO: Optimize to update only the changed embedding shardings
+        self._embedding_shardings: List[
+            EmbeddingSharding[
+                EmbeddingShardingContext,
+                KeyedJaggedTensor,
+                torch.Tensor,
+                torch.Tensor,
+            ]
+        ] = [
+            create_embedding_bag_sharding(
+                embedding_configs,
+                env,
+                device,
+                permute_embeddings=True,
+                qcomm_codecs_registry=self.qcomm_codecs_registry,
+            )
+            for embedding_configs in self.sharding_type_to_sharding_infos.values()
+        ]
+
+        self._create_lookups()
+        self._update_output_dist()
+
+        if env.process_group and dist.get_backend(env.process_group) != "fake":
+            self._initialize_torch_state(skip_registering=True)
+
+        self.load_state_dict(current_state)
+        return
+
     @property
     def fused_optimizer(self) -> KeyedOptimizer:
         return self._optim
@@ -1438,6 +1582,33 @@ def shardable_parameters(
             for name, param in module.embedding_bags.named_parameters()
         }
 
+    def reshard(
+        self,
+        sharded_module: ShardedEmbeddingBagCollection,
+        changed_shard_to_params: Dict[str, ParameterSharding],
+        env: ShardingEnv,
+        device: Optional[torch.device] = None,
+    ) -> ShardedEmbeddingBagCollection:
+        """
+        Updates the sharded module in place based on the changed_shard_to_params
+        which contains the new ParameterSharding with different shard placements.
+
+        Args:
+            sharded_module (ShardedEmbeddingBagCollection): The module to update
+            changed_shard_to_params (Dict[str, ParameterSharding]): A dictionary mapping
+                table names to their new parameter sharding configs. This should only
+                contain shards/table names that need to be moved
+            env (ShardingEnv): The sharding environment
+            device (Optional[torch.device]): The device to place the updated module on
+
+        Returns:
+            ShardedEmbeddingBagCollection: The updated sharded module
+        """
+
+        if len(changed_shard_to_params) > 0:
+            sharded_module.update_shards(changed_shard_to_params, env, device)
+        return sharded_module
+
     @property
     def module_type(self) -> Type[EmbeddingBagCollection]:
         return EmbeddingBagCollection

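For reference, here is a self-contained sketch of the permute-index derivation shared by `_create_output_dist` and the new `_update_output_dist` above. The table names and shard offsets are made-up illustration data; the snippet only mirrors the sorting logic from the diff, not the TorchRec module itself.

```python
# Standalone illustration of the sort used to build PermutePooledEmbeddings
# indices; the names and offsets below are invented example data.
from typing import Dict, List, Tuple

uncombined_embedding_names: List[str] = ["table_0", "table_1", "table_0"]
# Column offset of each shard within its table (shard_offsets[1] in the diff).
embedding_shard_offsets: List[int] = [64, 0, 0]

# First occurrence of each table name fixes the table order.
embedding_name_order: Dict[str, int] = {}
for i, name in enumerate(uncombined_embedding_names):
    embedding_name_order.setdefault(name, i)


def sort_key(item: Tuple[int, str]) -> Tuple[int, int]:
    index, name = item
    return (embedding_name_order[name], embedding_shard_offsets[index])


permute_indices = [
    i for i, _ in sorted(enumerate(uncombined_embedding_names), key=sort_key)
]
print(permute_indices)  # [2, 0, 1]: table_0's shards by offset (0, then 64), then table_1
```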