205 changes: 191 additions & 14 deletions torchrec/distributed/embeddingbag.py
@@ -27,6 +27,9 @@

import torch
from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
DenseTableBatchedEmbeddingBagsCodegen,
)
from tensordict import TensorDict
from torch import distributed as dist, nn, Tensor
from torch.autograd.profiler import record_function
@@ -50,6 +53,10 @@
)
from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
from torchrec.distributed.sharding.dynamic_sharding import (
shards_all_to_all,
update_state_dict_post_resharding,
)
from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
from torchrec.distributed.sharding.tw_sharding import TwPooledEmbeddingSharding
@@ -635,14 +642,17 @@ def __init__(
self._env = env
# output parameters as DTensor in state dict
self._output_dtensor: bool = env.output_dtensor

sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
module,
table_name_to_parameter_sharding,
"embedding_bags.",
fused_params,
self.sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = (
create_sharding_infos_by_sharding(
module,
table_name_to_parameter_sharding,
"embedding_bags.",
fused_params,
)
)
self._sharding_types: List[str] = list(
self.sharding_type_to_sharding_infos.keys()
)
self._sharding_types: List[str] = list(sharding_type_to_sharding_infos.keys())
self._embedding_shardings: List[
EmbeddingSharding[
EmbeddingShardingContext,
@@ -658,7 +668,7 @@
permute_embeddings=True,
qcomm_codecs_registry=self.qcomm_codecs_registry,
)
for embedding_configs in sharding_type_to_sharding_infos.values()
for embedding_configs in self.sharding_type_to_sharding_infos.values()
]

self._is_weighted: bool = module.is_weighted()
@@ -833,7 +843,7 @@ def _pre_load_state_dict_hook(
lookup = lookup.module
lookup.purge()

def _initialize_torch_state(self) -> None: # noqa
def _initialize_torch_state(self, skip_registering: bool = False) -> None: # noqa
"""
This provides consistency between this class and the EmbeddingBagCollection's
nn.Module API calls (state_dict, named_modules, etc)
@@ -1063,11 +1073,12 @@ def post_state_dict_hook(
destination_key = f"{prefix}embedding_bags.{table_name}.weight"
destination[destination_key] = sharded_kvtensor

self.register_state_dict_pre_hook(self._pre_state_dict_hook)
self._register_state_dict_hook(post_state_dict_hook)
self._register_load_state_dict_pre_hook(
self._pre_load_state_dict_hook, with_module=True
)
if not skip_registering:
self.register_state_dict_pre_hook(self._pre_state_dict_hook)
self._register_state_dict_hook(post_state_dict_hook)
self._register_load_state_dict_pre_hook(
self._pre_load_state_dict_hook, with_module=True
)
self.reset_parameters()

def reset_parameters(self) -> None:
@@ -1164,6 +1175,7 @@ def _create_output_dist(self) -> None:
self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims())
embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
self._dim_per_key = torch.tensor(self._embedding_dims, device=self._device)

embedding_shard_offsets: List[int] = [
meta.shard_offsets[1] if meta is not None else 0
for meta in embedding_shard_metadata
@@ -1179,6 +1191,38 @@
embedding_shard_offsets[i],
),
)

self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings(
self._uncombined_embedding_dims, permute_indices, self._device
)

def _update_output_dist(self) -> None:
embedding_shard_metadata: List[Optional[ShardMetadata]] = []
# TODO: Optimize to only go through embedding shardings with new ranks
self._output_dists: List[nn.Module] = []
self._embedding_names: List[str] = []
for sharding in self._embedding_shardings:
# TODO: if sharding type of table completely changes, need to regenerate everything
self._embedding_names.extend(sharding.embedding_names())
self._output_dists.append(sharding.create_output_dist(device=self._device))
embedding_shard_metadata.extend(sharding.embedding_shard_metadata())

embedding_shard_offsets: List[int] = [
meta.shard_offsets[1] if meta is not None else 0
for meta in embedding_shard_metadata
]
embedding_name_order: Dict[str, int] = {}
for i, name in enumerate(self._uncombined_embedding_names):
embedding_name_order.setdefault(name, i)

permute_indices = sorted(
range(len(self._uncombined_embedding_names)),
key=lambda i: (
embedding_name_order[self._uncombined_embedding_names[i]],
embedding_shard_offsets[i],
),
)

self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings(
self._uncombined_embedding_dims, permute_indices, self._device
)
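Both `_create_output_dist` and `_update_output_dist` build the same permutation: shard outputs are grouped by the first appearance of their table name and then ordered by shard column offset. A minimal, self-contained sketch of that ordering with made-up table names and offsets (not values from a real sharding plan):

# Illustrative inputs: one entry per shard.
uncombined_embedding_names = ["t1", "t2", "t1"]
embedding_shard_offsets = [64, 0, 0]

embedding_name_order = {}
for i, name in enumerate(uncombined_embedding_names):
    embedding_name_order.setdefault(name, i)

permute_indices = sorted(
    range(len(uncombined_embedding_names)),
    key=lambda i: (
        embedding_name_order[uncombined_embedding_names[i]],
        embedding_shard_offsets[i],
    ),
)
print(permute_indices)  # [2, 0, 1]: both t1 shards first (by offset), then t2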
@@ -1396,13 +1440,119 @@ def compute_and_output_dist(

return awaitable

def update_shards(
self,
changed_sharding_params: Dict[str, ParameterSharding], # NOTE: only delta
env: ShardingEnv,
device: Optional[torch.device],
) -> None:
"""
Update shards for this module based on the changed_sharding_params. This will:
1. Move current lookup tensors to CPU
2. Purge lookups
3. Call the shards_all_to_all collective to redistribute tensors
4. Update state_dict and other attributes to reflect new placements and shards
5. Create new lookups and load the updated state_dict

Args:
changed_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping
table names to their new parameter sharding configs. This should only
contain shards/table names that need to be moved.
env (ShardingEnv): The sharding environment for the module.
device (Optional[torch.device]): The device to place the updated module on.
"""

if env.output_dtensor:
raise RuntimeError("We do not yet support DTensor for resharding yet")
return

current_state = self.state_dict()
# TODO: Save Optimizers

saved_weights = {}
# TODO: Save lookup tensors to CPU so they do not have to be recreated from scratch
for i, lookup in enumerate(self._lookups):
for attribute, tbe_module in lookup.named_modules():
if type(tbe_module) is DenseTableBatchedEmbeddingBagsCodegen:
saved_weights[str(i) + "." + attribute] = tbe_module.weights.cpu()
# Note: lookup.purge should delete tbe_module and weights
# del tbe_module.weights
# del tbe_module
# pyre-ignore
lookup.purge()

# Deleting all lookups
self._lookups.clear()

local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all(
module=self,
state_dict=current_state,
device=device, # pyre-ignore
changed_sharding_params=changed_sharding_params,
env=env,
extend_shard_name=self.extend_shard_name,
)

current_state = update_state_dict_post_resharding(
state_dict=current_state,
shard_names_by_src_rank=local_shard_names_by_src_rank,
output_tensor=local_output_tensor,
new_sharding_params=changed_sharding_params,
curr_rank=dist.get_rank(),
extend_shard_name=self.extend_shard_name,
)

for name, param in changed_sharding_params.items():
self.module_sharding_plan[name] = param
# TODO: Support detecting old sharding type when sharding type is changing
for sharding_info in self.sharding_type_to_sharding_infos[
param.sharding_type
]:
if sharding_info.embedding_config.name == name:
sharding_info.param_sharding = param

self._sharding_types: List[str] = list(
self.sharding_type_to_sharding_infos.keys()
)
# TODO: Optimize to update only the changed embedding shardings
self._embedding_shardings: List[
EmbeddingSharding[
EmbeddingShardingContext,
KeyedJaggedTensor,
torch.Tensor,
torch.Tensor,
]
] = [
create_embedding_bag_sharding(
embedding_configs,
env,
device,
permute_embeddings=True,
qcomm_codecs_registry=self.qcomm_codecs_registry,
)
for embedding_configs in self.sharding_type_to_sharding_infos.values()
]

self._create_lookups()
self._update_output_dist()

if env.process_group and dist.get_backend(env.process_group) != "fake":
self._initialize_torch_state(skip_registering=True)

self.load_state_dict(current_state)
return
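For orientation, a hedged sketch of a delta-only `changed_sharding_params` argument. The table name "t1", the target ranks, and the surrounding objects (`sharded_ebc`, `env`, `device`) are assumptions; in practice the new `ParameterSharding`, including its sharding spec and placements, would come from a planner-generated plan rather than being edited by hand:

import copy

# Hypothetical delta: move only table "t1" (an assumed name) to new ranks.
# Tables whose placement does not change are left out of the dict.
new_param_sharding = copy.deepcopy(sharded_ebc.module_sharding_plan["t1"])
new_param_sharding.ranks = [2, 3]  # assumes ParameterSharding exposes `ranks`
changed_sharding_params = {"t1": new_param_sharding}

sharded_ebc.update_shards(changed_sharding_params, env, device)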

@property
def fused_optimizer(self) -> KeyedOptimizer:
return self._optim

def create_context(self) -> EmbeddingBagCollectionContext:
return EmbeddingBagCollectionContext()

@staticmethod
def extend_shard_name(shard_name: str) -> str:
return f"embedding_bags.{shard_name}.weight"


class EmbeddingBagCollectionSharder(BaseEmbeddingSharder[EmbeddingBagCollection]):
"""
@@ -1435,6 +1585,33 @@ def shardable_parameters(
for name, param in module.embedding_bags.named_parameters()
}

def reshard(
self,
sharded_module: ShardedEmbeddingBagCollection,
changed_shard_to_params: Dict[str, ParameterSharding],
env: ShardingEnv,
device: Optional[torch.device] = None,
) -> ShardedEmbeddingBagCollection:
"""
Updates the sharded module in place based on changed_shard_to_params, which
contains the new ParameterSharding configs with updated shard placements.

Args:
sharded_module (ShardedEmbeddingBagCollection): The module to update
changed_shard_to_params (Dict[str, ParameterSharding]): A dictionary mapping
table names to their new parameter sharding configs. This should only
contain shards/table names that need to be moved.
env (ShardingEnv): The sharding environment.
device (Optional[torch.device]): The device to place the updated module on.

Returns:
ShardedEmbeddingBagCollection: The updated sharded module
"""

if len(changed_shard_to_params) > 0:
sharded_module.update_shards(changed_shard_to_params, env, device)
return sharded_module
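A hedged usage sketch, assuming `sharder`, `sharded_ebc`, `env`, `device`, and the `changed_sharding_params` delta from the earlier sketch already exist. `reshard` mutates the module in place and returns the same instance, and an empty delta is a no-op:

sharded_ebc = sharder.reshard(sharded_ebc, changed_sharding_params, env, device=device)

# An empty delta leaves the module untouched and returns it unchanged.
assert sharder.reshard(sharded_ebc, {}, env) is sharded_ebc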

@property
def module_type(self) -> Type[EmbeddingBagCollection]:
return EmbeddingBagCollection