diff --git a/torchrec/distributed/mc_embedding_modules.py b/torchrec/distributed/mc_embedding_modules.py index e66cb674b..b817f020a 100644 --- a/torchrec/distributed/mc_embedding_modules.py +++ b/torchrec/distributed/mc_embedding_modules.py @@ -129,6 +129,9 @@ def __init__( ) ) self._return_remapped_features: bool = module._return_remapped_features + self._allow_in_place_embed_weight_update: bool = ( + module._allow_in_place_embed_weight_update + ) # pyre-ignore self._table_to_tbe_and_index = {} @@ -202,12 +205,22 @@ def _evict(self, evictions_per_table: Dict[str, Optional[torch.Tensor]]) -> None init_fn = self._embedding_module._table_name_to_config[ table ].init_fn - # Set evicted indices to original init_fn instead of all zeros - # pyre-ignore [29] - table_weight_param[evictions_indices_for_table] = init_fn( - table_weight_param[evictions_indices_for_table] - ) + if self._allow_in_place_embed_weight_update: + # In-place update with .data to bypass PyTorch's autograd tracking. + # This is required for model training with multiple forward passes where the autograd graph + # is already created. Direct tensor modification would trigger PyTorch's in-place operation + # checks and invalidate gradients, while .data allows safe reinitialization of evicted + # embeddings without affecting the computational graph. + # pyre-ignore [29] + table_weight_param.data[evictions_indices_for_table] = init_fn( + table_weight_param[evictions_indices_for_table] + ) + else: + # pyre-ignore [29] + table_weight_param[evictions_indices_for_table] = init_fn( + table_weight_param[evictions_indices_for_table] + ) def compute( self, diff --git a/torchrec/distributed/tests/test_mc_embedding.py b/torchrec/distributed/tests/test_mc_embedding.py index 60de369d1..64c3ca14e 100644 --- a/torchrec/distributed/tests/test_mc_embedding.py +++ b/torchrec/distributed/tests/test_mc_embedding.py @@ -59,6 +59,7 @@ def __init__( device: torch.device, return_remapped: bool = False, input_hash_size: int = 4000, + allow_in_place_embed_weight_update: bool = False, ) -> None: super().__init__() self._return_remapped = return_remapped @@ -91,6 +92,7 @@ def __init__( embedding_configs=tables, ), return_remapped_features=self._return_remapped, + allow_in_place_embed_weight_update=allow_in_place_embed_weight_update, ) ) @@ -242,6 +244,106 @@ def _test_sharding_and_remapping( # noqa C901 # TODO: validate embedding rows, and eviction +def _test_in_place_embd_weight_update( # noqa C901 + output_keys: List[str], + tables: List[EmbeddingConfig], + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]], + initial_state_per_rank: List[Dict[str, torch.Tensor]], + final_state_per_rank: List[Dict[str, torch.Tensor]], + sharder: ModuleSharder[nn.Module], + backend: str, + local_size: Optional[int] = None, + input_hash_size: int = 4000, + allow_in_place_embed_weight_update: bool = True, +) -> None: + + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + kjt_input = kjt_input_per_rank[rank].to(ctx.device) + kjt_out_per_iter = [ + kjt[rank].to(ctx.device) for kjt in kjt_out_per_iter_per_rank + ] + return_remapped: bool = True + sparse_arch = SparseArch( + tables, + torch.device("meta"), + return_remapped=return_remapped, + input_hash_size=input_hash_size, + allow_in_place_embed_weight_update=allow_in_place_embed_weight_update, + ) + apply_optimizer_in_backward( + RowWiseAdagrad, + [ + sparse_arch._mc_ec._embedding_collection.embeddings["table_0"].weight, 
+ sparse_arch._mc_ec._embedding_collection.embeddings["table_1"].weight, + ], + {"lr": 0.01}, + ) + module_sharding_plan = construct_module_sharding_plan( + sparse_arch._mc_ec, + per_param_sharding={"table_0": row_wise(), "table_1": row_wise()}, + local_size=local_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + sharder=sharder, + ) + + sharded_sparse_arch = _shard_modules( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"_mc_ec": module_sharding_plan}), + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. + env=ShardingEnv.from_process_group(ctx.pg), + sharders=[sharder], + device=ctx.device, + ) + + initial_state_dict = sharded_sparse_arch.state_dict() + for key, sharded_tensor in initial_state_dict.items(): + postfix = ".".join(key.split(".")[-2:]) + if postfix in initial_state_per_rank[ctx.rank]: + tensor = sharded_tensor.local_shards()[0].tensor.cpu() + assert torch.equal( + tensor, initial_state_per_rank[ctx.rank][postfix] + ), f"initial state {key} on {ctx.rank} does not match, got {tensor}, expect {initial_state_per_rank[rank][postfix]}" + + sharded_sparse_arch.load_state_dict(initial_state_dict) + + # sharded model + # each rank gets a subbatch + loss1, remapped_ids1 = sharded_sparse_arch(kjt_input) + loss2, remapped_ids2 = sharded_sparse_arch(kjt_input) + + if not allow_in_place_embed_weight_update: + # Without in-place overwrite the backward pass will fail due to tensor version mismatch + with unittest.TestCase().assertRaisesRegex( + RuntimeError, + "one of the variables needed for gradient computation has been modified by an inplace operation", + ): + loss1.backward() + else: + loss1.backward() + loss2.backward() + final_state_dict = sharded_sparse_arch.state_dict() + for key, sharded_tensor in final_state_dict.items(): + postfix = ".".join(key.split(".")[-2:]) + if postfix in final_state_per_rank[ctx.rank]: + tensor = sharded_tensor.local_shards()[0].tensor.cpu() + assert torch.equal( + tensor, final_state_per_rank[ctx.rank][postfix] + ), f"initial state {key} on {ctx.rank} does not match, got {tensor}, expect {final_state_per_rank[rank][postfix]}" + + remapped_ids = [remapped_ids1, remapped_ids2] + for key in output_keys: + for i, kjt_out in enumerate(kjt_out_per_iter): + assert torch.equal( + remapped_ids[i][key].values(), + kjt_out[key].values(), + ), f"feature {key} on {ctx.rank} iteration {i} does not match, got {remapped_ids[i][key].values()}, expect {kjt_out[key].values()}" + + def _test_sharding_and_resharding( # noqa C901 tables: List[EmbeddingConfig], rank: int, @@ -1016,3 +1118,166 @@ def test_sharding_zch_mc_ec_dedup_input_error(self, backend: str) -> None: ), except AssertionError as e: self.assertTrue("0 != 1" in str(e)) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-ignore + @given( + backend=st.sampled_from(["nccl"]), + allow_in_place_embed_weight_update=st.booleans(), + ) + @settings(deadline=None) + def test_in_place_embd_weight_update( + self, backend: str, allow_in_place_embed_weight_update: bool + ) -> None: + + WORLD_SIZE = 2 + + embedding_config = [ + EmbeddingConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=8, + num_embeddings=16, + ), + EmbeddingConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=8, + num_embeddings=32, + ), + ] + + kjt_input_per_rank = [ # noqa + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", 
"feature_1", "feature_2"], + values=torch.LongTensor( + [1000, 2000, 1001, 2000, 2001, 2002, 1, 1, 1], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1", "feature_2"], + values=torch.LongTensor( + [ + 1000, + 1002, + 1004, + 2000, + 2002, + 2004, + 2, + 2, + 2, + ], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + + kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]] = [] + kjt_out_per_iter_per_rank.append( + [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [7, 15, 7, 31, 31, 31], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [7, 7, 7, 31, 31, 31], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + ) + # TODO: cleanup sorting so more dedugable/logical initial fill + + kjt_out_per_iter_per_rank.append( + [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [3, 14, 4, 27, 29, 28], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [3, 5, 6, 27, 28, 30], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + ) + + initial_state_per_rank = [ + { + "table_0._mch_remapped_ids_mapping": torch.arange(8, dtype=torch.int64), + "table_1._mch_remapped_ids_mapping": torch.arange( + 16, dtype=torch.int64 + ), + }, + { + "table_0._mch_remapped_ids_mapping": torch.arange( + start=8, end=16, dtype=torch.int64 + ), + "table_1._mch_remapped_ids_mapping": torch.arange( + start=16, end=32, dtype=torch.int64 + ), + }, + ] + max_int = torch.iinfo(torch.int64).max + + final_state_per_rank = [ + { + "table_0._mch_sorted_raw_ids": torch.LongTensor( + [1000, 1001, 1002, 1004] + [max_int] * 4 + ), + "table_1._mch_sorted_raw_ids": torch.LongTensor([max_int] * 16), + "table_0._mch_remapped_ids_mapping": torch.LongTensor( + [3, 4, 5, 6, 0, 1, 2, 7] + ), + "table_1._mch_remapped_ids_mapping": torch.arange( + 16, dtype=torch.int64 + ), + }, + { + "table_0._mch_sorted_raw_ids": torch.LongTensor([2000] + [max_int] * 7), + "table_1._mch_sorted_raw_ids": torch.LongTensor( + [2000, 2001, 2002, 2004] + [max_int] * 12 + ), + "table_0._mch_remapped_ids_mapping": torch.LongTensor( + [14, 8, 9, 10, 11, 12, 13, 15] + ), + "table_1._mch_remapped_ids_mapping": torch.LongTensor( + [27, 29, 28, 30, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 31] + ), + }, + ] + + self._run_multi_process_test( + callable=_test_in_place_embd_weight_update, + output_keys=["feature_0", "feature_1"], + world_size=WORLD_SIZE, + tables=embedding_config, + kjt_input_per_rank=kjt_input_per_rank, + kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank, + initial_state_per_rank=initial_state_per_rank, + final_state_per_rank=final_state_per_rank, + sharder=ManagedCollisionEmbeddingCollectionSharder(), + backend=backend, + allow_in_place_embed_weight_update=allow_in_place_embed_weight_update, + ) diff --git a/torchrec/distributed/tests/test_mc_embeddingbag.py b/torchrec/distributed/tests/test_mc_embeddingbag.py index e891e8841..a24caf2cc 100644 --- a/torchrec/distributed/tests/test_mc_embeddingbag.py +++ b/torchrec/distributed/tests/test_mc_embeddingbag.py @@ -9,7 +9,7 @@ import copy import unittest -from 
typing import Dict, List, Optional, Tuple
+from typing import Dict, Final, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -43,12 +43,103 @@
 from torchrec.test_utils import skip_if_asan_class
 
 
+# Global constants for testing ShardedManagedCollisionEmbeddingBagCollection
+
+WORLD_SIZE = 2
+
+embedding_bag_config: Final[List[EmbeddingBagConfig]] = [
+    EmbeddingBagConfig(
+        name="table_0",
+        feature_names=["feature_0"],
+        embedding_dim=8,
+        num_embeddings=16,
+    ),
+    EmbeddingBagConfig(
+        name="table_1",
+        feature_names=["feature_1"],
+        embedding_dim=8,
+        num_embeddings=32,
+    ),
+]
+
+# Input KeyedJaggedTensors for each rank in distributed tests
+kjt_input_per_rank: Final[List[KeyedJaggedTensor]] = [
+    KeyedJaggedTensor.from_lengths_sync(
+        keys=["feature_0", "feature_1", "feature_2"],
+        values=torch.LongTensor(
+            [1000, 2000, 1001, 2000, 2001, 2002, 1, 1, 1],
+        ),
+        lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
+        weights=None,
+    ),
+    KeyedJaggedTensor.from_lengths_sync(
+        keys=["feature_0", "feature_1", "feature_2"],
+        values=torch.LongTensor(
+            [
+                1000,
+                1002,
+                1004,
+                2000,
+                2002,
+                2004,
+                1,
+                1,
+                1,
+            ],
+        ),
+        lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
+        weights=None,
+    ),
+]
+
+# Expected remapped outputs per iteration per rank for validation
+kjt_out_per_iter_per_rank: Final[List[List[KeyedJaggedTensor]]] = [
+    [
+        KeyedJaggedTensor.from_lengths_sync(
+            keys=["feature_0", "feature_1"],
+            values=torch.LongTensor(
+                [7, 15, 7, 31, 31, 31],
+            ),
+            lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+            weights=None,
+        ),
+        KeyedJaggedTensor.from_lengths_sync(
+            keys=["feature_0", "feature_1"],
+            values=torch.LongTensor(
+                [7, 7, 7, 31, 31, 31],
+            ),
+            lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+            weights=None,
+        ),
+    ],
+    [
+        KeyedJaggedTensor.from_lengths_sync(
+            keys=["feature_0", "feature_1"],
+            values=torch.LongTensor(
+                [3, 14, 4, 27, 29, 28],
+            ),
+            lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+            weights=None,
+        ),
+        KeyedJaggedTensor.from_lengths_sync(
+            keys=["feature_0", "feature_1"],
+            values=torch.LongTensor(
+                [3, 5, 6, 27, 28, 30],
+            ),
+            lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+            weights=None,
+        ),
+    ],
+]
+
+
 class SparseArch(nn.Module):
     def __init__(
         self,
         tables: List[EmbeddingBagConfig],
         device: torch.device,
         return_remapped: bool = False,
+        allow_in_place_embed_weight_update: bool = False,
     ) -> None:
         super().__init__()
         self._return_remapped = return_remapped
@@ -81,6 +172,7 @@ def __init__(
                     embedding_configs=tables,
                 ),
                 return_remapped_features=self._return_remapped,
+                allow_in_place_embed_weight_update=allow_in_place_embed_weight_update,
             )
         )
 
@@ -268,6 +360,87 @@ def _test_sharding_and_remapping(  # noqa C901
     # TODO: validate embedding rows, and eviction
 
+
+def _test_in_place_embd_weight_update(  # noqa C901
+    output_keys: List[str],
+    tables: List[EmbeddingBagConfig],
+    rank: int,
+    world_size: int,
+    kjt_input_per_rank: List[KeyedJaggedTensor],
+    kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]],
+    sharder: ModuleSharder[nn.Module],
+    backend: str,
+    local_size: Optional[int] = None,
+    allow_in_place_embed_weight_update: bool = True,
+) -> None:
+    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        kjt_input = kjt_input_per_rank[rank].to(ctx.device)
+        kjt_out_per_iter = [
+            kjt[rank].to(ctx.device) for kjt in kjt_out_per_iter_per_rank
+        ]
+        return_remapped: bool = True
+        sparse_arch = SparseArch(
+            tables,
+            torch.device("meta"),
+            return_remapped=return_remapped,
+            
allow_in_place_embed_weight_update=allow_in_place_embed_weight_update, + ) + apply_optimizer_in_backward( + RowWiseAdagrad, + [ + sparse_arch._mc_ebc._embedding_bag_collection.embedding_bags[ + "table_0" + ].weight, + sparse_arch._mc_ebc._embedding_bag_collection.embedding_bags[ + "table_1" + ].weight, + ], + {"lr": 0.01}, + ) + module_sharding_plan = construct_module_sharding_plan( + sparse_arch._mc_ebc, + per_param_sharding={"table_0": row_wise(), "table_1": row_wise()}, + local_size=local_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + sharder=sharder, + ) + + sharded_sparse_arch = _shard_modules( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"_mc_ebc": module_sharding_plan}), + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. + env=ShardingEnv.from_process_group(ctx.pg), + sharders=[sharder], + device=ctx.device, + ) + + test_state_dict = sharded_sparse_arch.state_dict() + sharded_sparse_arch.load_state_dict(test_state_dict) + + # sharded model + # each rank gets a subbatch + loss1, remapped_ids1 = sharded_sparse_arch(kjt_input) + loss2, remapped_ids2 = sharded_sparse_arch(kjt_input) + if not allow_in_place_embed_weight_update: + # Without in-place overwrite the backward pass will fail due to tensor version mismatch + with unittest.TestCase().assertRaisesRegex( + RuntimeError, + "one of the variables needed for gradient computation has been modified by an inplace operation", + ): + loss1.backward() + else: + loss1.backward() + loss2.backward() + remapped_ids = [remapped_ids1, remapped_ids2] + for key in output_keys: + for i, kjt_out in enumerate(kjt_out_per_iter): + assert torch.equal( + remapped_ids[i][key].values(), + kjt_out[key].values(), + ), f"feature {key} on {ctx.rank} iteration {i} does not match, got {remapped_ids[i][key].values()}, expect {kjt_out[key].values()}" + + @skip_if_asan_class class ShardedMCEmbeddingBagCollectionParallelTest(MultiProcessTestBase): @unittest.skipIf( @@ -311,22 +484,6 @@ def test_uneven_sharding(self, backend: str) -> None: @given(backend=st.sampled_from(["nccl"])) @settings(deadline=None) def test_even_sharding(self, backend: str) -> None: - WORLD_SIZE = 2 - - embedding_bag_config = [ - EmbeddingBagConfig( - name="table_0", - feature_names=["feature_0"], - embedding_dim=8, - num_embeddings=16, - ), - EmbeddingBagConfig( - name="table_1", - feature_names=["feature_1"], - embedding_dim=8, - num_embeddings=32, - ), - ] self._run_multi_process_test( callable=_test_sharding, @@ -344,99 +501,33 @@ def test_even_sharding(self, backend: str) -> None: @given(backend=st.sampled_from(["nccl"])) @settings(deadline=None) def test_sharding_zch_mc_ebc(self, backend: str) -> None: - - WORLD_SIZE = 2 - - embedding_bag_config = [ - EmbeddingBagConfig( - name="table_0", - feature_names=["feature_0"], - embedding_dim=8, - num_embeddings=16, - ), - EmbeddingBagConfig( - name="table_1", - feature_names=["feature_1"], - embedding_dim=8, - num_embeddings=32, - ), - ] - - kjt_input_per_rank = [ # noqa - KeyedJaggedTensor.from_lengths_sync( - keys=["feature_0", "feature_1", "feature_2"], - values=torch.LongTensor( - [1000, 2000, 1001, 2000, 2001, 2002, 1, 1, 1], - ), - lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]), - weights=None, - ), - KeyedJaggedTensor.from_lengths_sync( - keys=["feature_0", "feature_1", "feature_2"], - values=torch.LongTensor( - [ - 1000, - 1002, - 1004, - 2000, - 2002, - 2004, - 1, - 1, - 1, - ], - ), - lengths=torch.LongTensor([1, 1, 1, 1, 
1, 1, 1, 1, 1]), - weights=None, - ), - ] - - kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]] = [] - kjt_out_per_iter_per_rank.append( - [ - KeyedJaggedTensor.from_lengths_sync( - keys=["feature_0", "feature_1"], - values=torch.LongTensor( - [7, 15, 7, 31, 31, 31], - ), - lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), - weights=None, - ), - KeyedJaggedTensor.from_lengths_sync( - keys=["feature_0", "feature_1"], - values=torch.LongTensor( - [7, 7, 7, 31, 31, 31], - ), - lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), - weights=None, - ), - ] + self._run_multi_process_test( + callable=_test_sharding_and_remapping, + output_keys=["feature_0", "feature_1"], + world_size=WORLD_SIZE, + tables=embedding_bag_config, + kjt_input_per_rank=kjt_input_per_rank, + kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank, + sharder=ManagedCollisionEmbeddingBagCollectionSharder(), + backend=backend, ) - # TODO: cleanup sorting so more dedugable/logical initial fill - kjt_out_per_iter_per_rank.append( - [ - KeyedJaggedTensor.from_lengths_sync( - keys=["feature_0", "feature_1"], - values=torch.LongTensor( - [3, 14, 4, 27, 29, 28], - ), - lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), - weights=None, - ), - KeyedJaggedTensor.from_lengths_sync( - keys=["feature_0", "feature_1"], - values=torch.LongTensor( - [3, 5, 6, 27, 28, 30], - ), - lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), - weights=None, - ), - ] - ) + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-ignore + @given( + backend=st.sampled_from(["nccl"]), + allow_in_place_embed_weight_update=st.booleans(), + ) + @settings(deadline=None) + def test_in_place_embd_weight_update( + self, backend: str, allow_in_place_embed_weight_update: bool + ) -> None: self._run_multi_process_test( - callable=_test_sharding_and_remapping, + callable=_test_in_place_embd_weight_update, output_keys=["feature_0", "feature_1"], world_size=WORLD_SIZE, tables=embedding_bag_config, @@ -444,4 +535,5 @@ def test_sharding_zch_mc_ebc(self, backend: str) -> None: kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank, sharder=ManagedCollisionEmbeddingBagCollectionSharder(), backend=backend, + allow_in_place_embed_weight_update=allow_in_place_embed_weight_update, ) diff --git a/torchrec/modules/mc_embedding_modules.py b/torchrec/modules/mc_embedding_modules.py index 834ad667c..6e7850dba 100644 --- a/torchrec/modules/mc_embedding_modules.py +++ b/torchrec/modules/mc_embedding_modules.py @@ -39,6 +39,9 @@ class BaseManagedCollisionEmbeddingCollection(nn.Module): managed_collision_modules: Dict of managed collision modules return_remapped_features (bool): whether to return remapped input features in addition to embeddings + allow_in_place_embed_weight_update(bool): Enables in-place update of embedding + weights on eviction. When enabled, this flag allows updates to embedding + weights without modifying the autograd graph. 
""" @@ -47,10 +50,12 @@ def __init__( embedding_module: Union[EmbeddingBagCollection, EmbeddingCollection], managed_collision_collection: ManagedCollisionCollection, return_remapped_features: bool = False, + allow_in_place_embed_weight_update: bool = False, ) -> None: super().__init__() self._managed_collision_collection = managed_collision_collection self._return_remapped_features = return_remapped_features + self._allow_in_place_embed_weight_update = allow_in_place_embed_weight_update self._embedding_module: Union[EmbeddingBagCollection, EmbeddingCollection] = ( embedding_module ) @@ -97,10 +102,13 @@ class ManagedCollisionEmbeddingCollection(BaseManagedCollisionEmbeddingCollectio For details of input and output types, see EmbeddingCollection Args: - embedding_module: EmbeddingCollection to lookup embeddings - managed_collision_modules: Dict of managed collision modules + embedding_collection: EmbeddingCollection to lookup embeddings + managed_collision_collection: Dict of managed collision modules return_remapped_features (bool): whether to return remapped input features in addition to embeddings + allow_in_place_embed_weight_update(bool): enable in place update of embedding + weights on evict. This flag when enabled will allow update embedding + weights without modifying of autograd graph. """ @@ -109,9 +117,13 @@ def __init__( embedding_collection: EmbeddingCollection, managed_collision_collection: ManagedCollisionCollection, return_remapped_features: bool = False, + allow_in_place_embed_weight_update: bool = False, ) -> None: super().__init__( - embedding_collection, managed_collision_collection, return_remapped_features + embedding_collection, + managed_collision_collection, + return_remapped_features, + allow_in_place_embed_weight_update, ) # For consistency with embedding bag collection @@ -132,6 +144,10 @@ class ManagedCollisionEmbeddingBagCollection(BaseManagedCollisionEmbeddingCollec managed_collision_modules: Dict of managed collision modules return_remapped_features (bool): whether to return remapped input features in addition to embeddings + allow_in_place_embed_weight_update(bool): Enables in-place update of embedding + weights on eviction. When enabled, this flag allows updates to embedding + weights without modifying the autograd graph. + """ @@ -140,11 +156,13 @@ def __init__( embedding_bag_collection: EmbeddingBagCollection, managed_collision_collection: ManagedCollisionCollection, return_remapped_features: bool = False, + allow_in_place_embed_weight_update: bool = False, ) -> None: super().__init__( embedding_bag_collection, managed_collision_collection, return_remapped_features, + allow_in_place_embed_weight_update, ) # For backwards compat, as references existed in tests