23 changes: 18 additions & 5 deletions torchrec/distributed/mc_embedding_modules.py
@@ -129,6 +129,9 @@ def __init__(
)
)
self._return_remapped_features: bool = module._return_remapped_features
self._allow_in_place_embed_weight_update: bool = (
module._allow_in_place_embed_weight_update
)

# pyre-ignore
self._table_to_tbe_and_index = {}
@@ -202,12 +205,22 @@ def _evict(self, evictions_per_table: Dict[str, Optional[torch.Tensor]]) -> None
init_fn = self._embedding_module._table_name_to_config[
table
].init_fn

# Set evicted indices to original init_fn instead of all zeros
# pyre-ignore [29]
table_weight_param[evictions_indices_for_table] = init_fn(
table_weight_param[evictions_indices_for_table]
)
if self._allow_in_place_embed_weight_update:
# In-place update with .data to bypass PyTorch's autograd tracking.
# This is required for model training with multiple forward passes where the autograd graph
# is already created. Direct tensor modification would trigger PyTorch's in-place operation
# checks and invalidate gradients, while .data allows safe reinitialization of evicted
# embeddings without affecting the computational graph.
# pyre-ignore [29]
table_weight_param.data[evictions_indices_for_table] = init_fn(
table_weight_param[evictions_indices_for_table]
)
else:
# pyre-ignore [29]
table_weight_param[evictions_indices_for_table] = init_fn(
table_weight_param[evictions_indices_for_table]
)

def compute(
self,
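Aside (illustration, not part of the diff): the comment in the new branch above explains why the flag routes the eviction write through `.data`. A direct indexed assignment on a parameter bumps autograd's version counter, so a pending backward pass that saved that parameter fails with the same "modified by an inplace operation" error the new test asserts on, whereas a write through `.data` updates the same storage but bypasses the version check. The minimal standalone sketch below is not TorchRec code; the 4-row table, `torch.sin`, and the `init_fn` stand-in are illustrative assumptions.

import torch

# Hypothetical 4-row "embedding table" and a stand-in for the table's init_fn.
weight = torch.nn.Parameter(torch.ones(4, 3))
init_fn = torch.nn.init.normal_

out = torch.sin(weight)  # sin() saves `weight` for its backward pass
loss = out.sum()

evicted_rows = torch.tensor([1, 2])
with torch.no_grad():
    # Direct write: bumps `weight`'s version counter, so the pending
    # loss.backward() below would raise "one of the variables needed for
    # gradient computation has been modified by an inplace operation".
    # weight[evicted_rows] = init_fn(weight[evicted_rows])

    # Write through .data: updates the same storage but bypasses autograd's
    # version check, so backward still runs (against the updated row values).
    weight.data[evicted_rows] = init_fn(weight[evicted_rows])

loss.backward()  # succeeds on the .data path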
265 changes: 265 additions & 0 deletions torchrec/distributed/tests/test_mc_embedding.py
@@ -59,6 +59,7 @@ def __init__(
device: torch.device,
return_remapped: bool = False,
input_hash_size: int = 4000,
allow_in_place_embed_weight_update: bool = False,
) -> None:
super().__init__()
self._return_remapped = return_remapped
@@ -91,6 +92,7 @@ def __init__(
embedding_configs=tables,
),
return_remapped_features=self._return_remapped,
allow_in_place_embed_weight_update=allow_in_place_embed_weight_update,
)
)

@@ -242,6 +244,106 @@ def _test_sharding_and_remapping( # noqa C901
# TODO: validate embedding rows, and eviction


def _test_in_place_embd_weight_update( # noqa C901
output_keys: List[str],
tables: List[EmbeddingConfig],
rank: int,
world_size: int,
kjt_input_per_rank: List[KeyedJaggedTensor],
kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]],
initial_state_per_rank: List[Dict[str, torch.Tensor]],
final_state_per_rank: List[Dict[str, torch.Tensor]],
sharder: ModuleSharder[nn.Module],
backend: str,
local_size: Optional[int] = None,
input_hash_size: int = 4000,
allow_in_place_embed_weight_update: bool = True,
) -> None:

with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
kjt_input = kjt_input_per_rank[rank].to(ctx.device)
kjt_out_per_iter = [
kjt[rank].to(ctx.device) for kjt in kjt_out_per_iter_per_rank
]
return_remapped: bool = True
sparse_arch = SparseArch(
tables,
torch.device("meta"),
return_remapped=return_remapped,
input_hash_size=input_hash_size,
allow_in_place_embed_weight_update=allow_in_place_embed_weight_update,
)
apply_optimizer_in_backward(
RowWiseAdagrad,
[
sparse_arch._mc_ec._embedding_collection.embeddings["table_0"].weight,
sparse_arch._mc_ec._embedding_collection.embeddings["table_1"].weight,
],
{"lr": 0.01},
)
module_sharding_plan = construct_module_sharding_plan(
sparse_arch._mc_ec,
per_param_sharding={"table_0": row_wise(), "table_1": row_wise()},
local_size=local_size,
world_size=world_size,
device_type="cuda" if torch.cuda.is_available() else "cpu",
sharder=sharder,
)

sharded_sparse_arch = _shard_modules(
module=copy.deepcopy(sparse_arch),
plan=ShardingPlan({"_mc_ec": module_sharding_plan}),
# pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got
# `Optional[ProcessGroup]`.
env=ShardingEnv.from_process_group(ctx.pg),
sharders=[sharder],
device=ctx.device,
)

initial_state_dict = sharded_sparse_arch.state_dict()
for key, sharded_tensor in initial_state_dict.items():
postfix = ".".join(key.split(".")[-2:])
if postfix in initial_state_per_rank[ctx.rank]:
tensor = sharded_tensor.local_shards()[0].tensor.cpu()
assert torch.equal(
tensor, initial_state_per_rank[ctx.rank][postfix]
), f"initial state {key} on {ctx.rank} does not match, got {tensor}, expect {initial_state_per_rank[rank][postfix]}"

sharded_sparse_arch.load_state_dict(initial_state_dict)

# sharded model
# each rank gets a subbatch
loss1, remapped_ids1 = sharded_sparse_arch(kjt_input)
loss2, remapped_ids2 = sharded_sparse_arch(kjt_input)

if not allow_in_place_embed_weight_update:
# Without the in-place (.data) overwrite, the backward pass fails due to a tensor version mismatch
with unittest.TestCase().assertRaisesRegex(
RuntimeError,
"one of the variables needed for gradient computation has been modified by an inplace operation",
):
loss1.backward()
else:
loss1.backward()
loss2.backward()
final_state_dict = sharded_sparse_arch.state_dict()
for key, sharded_tensor in final_state_dict.items():
postfix = ".".join(key.split(".")[-2:])
if postfix in final_state_per_rank[ctx.rank]:
tensor = sharded_tensor.local_shards()[0].tensor.cpu()
assert torch.equal(
tensor, final_state_per_rank[ctx.rank][postfix]
), f"initial state {key} on {ctx.rank} does not match, got {tensor}, expect {final_state_per_rank[rank][postfix]}"

remapped_ids = [remapped_ids1, remapped_ids2]
for key in output_keys:
for i, kjt_out in enumerate(kjt_out_per_iter):
assert torch.equal(
remapped_ids[i][key].values(),
kjt_out[key].values(),
), f"feature {key} on {ctx.rank} iteration {i} does not match, got {remapped_ids[i][key].values()}, expect {kjt_out[key].values()}"


def _test_sharding_and_resharding( # noqa C901
tables: List[EmbeddingConfig],
rank: int,
@@ -1016,3 +1118,166 @@ def test_sharding_zch_mc_ec_dedup_input_error(self, backend: str) -> None:
),
except AssertionError as e:
self.assertTrue("0 != 1" in str(e))

@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
# pyre-ignore
@given(
backend=st.sampled_from(["nccl"]),
allow_in_place_embed_weight_update=st.booleans(),
)
@settings(deadline=None)
def test_in_place_embd_weight_update(
self, backend: str, allow_in_place_embed_weight_update: bool
) -> None:

WORLD_SIZE = 2

embedding_config = [
EmbeddingConfig(
name="table_0",
feature_names=["feature_0"],
embedding_dim=8,
num_embeddings=16,
),
EmbeddingConfig(
name="table_1",
feature_names=["feature_1"],
embedding_dim=8,
num_embeddings=32,
),
]

kjt_input_per_rank = [ # noqa
KeyedJaggedTensor.from_lengths_sync(
keys=["feature_0", "feature_1", "feature_2"],
values=torch.LongTensor(
[1000, 2000, 1001, 2000, 2001, 2002, 1, 1, 1],
),
lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
weights=None,
),
KeyedJaggedTensor.from_lengths_sync(
keys=["feature_0", "feature_1", "feature_2"],
values=torch.LongTensor(
[
1000,
1002,
1004,
2000,
2002,
2004,
2,
2,
2,
],
),
lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
weights=None,
),
]

kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]] = []
kjt_out_per_iter_per_rank.append(
[
KeyedJaggedTensor.from_lengths_sync(
keys=["feature_0", "feature_1"],
values=torch.LongTensor(
[7, 15, 7, 31, 31, 31],
),
lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
weights=None,
),
KeyedJaggedTensor.from_lengths_sync(
keys=["feature_0", "feature_1"],
values=torch.LongTensor(
[7, 7, 7, 31, 31, 31],
),
lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
weights=None,
),
]
)
# TODO: clean up sorting so the initial fill is more debuggable/logical

kjt_out_per_iter_per_rank.append(
[
KeyedJaggedTensor.from_lengths_sync(
keys=["feature_0", "feature_1"],
values=torch.LongTensor(
[3, 14, 4, 27, 29, 28],
),
lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
weights=None,
),
KeyedJaggedTensor.from_lengths_sync(
keys=["feature_0", "feature_1"],
values=torch.LongTensor(
[3, 5, 6, 27, 28, 30],
),
lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
weights=None,
),
]
)

initial_state_per_rank = [
{
"table_0._mch_remapped_ids_mapping": torch.arange(8, dtype=torch.int64),
"table_1._mch_remapped_ids_mapping": torch.arange(
16, dtype=torch.int64
),
},
{
"table_0._mch_remapped_ids_mapping": torch.arange(
start=8, end=16, dtype=torch.int64
),
"table_1._mch_remapped_ids_mapping": torch.arange(
start=16, end=32, dtype=torch.int64
),
},
]
max_int = torch.iinfo(torch.int64).max

final_state_per_rank = [
{
"table_0._mch_sorted_raw_ids": torch.LongTensor(
[1000, 1001, 1002, 1004] + [max_int] * 4
),
"table_1._mch_sorted_raw_ids": torch.LongTensor([max_int] * 16),
"table_0._mch_remapped_ids_mapping": torch.LongTensor(
[3, 4, 5, 6, 0, 1, 2, 7]
),
"table_1._mch_remapped_ids_mapping": torch.arange(
16, dtype=torch.int64
),
},
{
"table_0._mch_sorted_raw_ids": torch.LongTensor([2000] + [max_int] * 7),
"table_1._mch_sorted_raw_ids": torch.LongTensor(
[2000, 2001, 2002, 2004] + [max_int] * 12
),
"table_0._mch_remapped_ids_mapping": torch.LongTensor(
[14, 8, 9, 10, 11, 12, 13, 15]
),
"table_1._mch_remapped_ids_mapping": torch.LongTensor(
[27, 29, 28, 30, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 31]
),
},
]

self._run_multi_process_test(
callable=_test_in_place_embd_weight_update,
output_keys=["feature_0", "feature_1"],
world_size=WORLD_SIZE,
tables=embedding_config,
kjt_input_per_rank=kjt_input_per_rank,
kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank,
initial_state_per_rank=initial_state_per_rank,
final_state_per_rank=final_state_per_rank,
sharder=ManagedCollisionEmbeddingCollectionSharder(),
backend=backend,
allow_in_place_embed_weight_update=allow_in_place_embed_weight_update,
)