Commit c7fc837

aliafzal authored and facebook-github-bot committed
Multi forward MCH eviction fix (#2836)
Summary:
Pull Request resolved: #2836

## Issue:
Direct tensor modification during training with multiple forward passes breaks PyTorch's autograd graph, causing a "one of the variables needed for gradient computation has been modified by an inplace operation" runtime error.

## Solution:
Use in-place updates with the .data accessor to safely reinitialize evicted embeddings without invalidating gradient computation.

Reviewed By: dstaay-fb

Differential Revision: D71491003
1 parent f0ae23d commit c7fc837
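
To make the failure mode concrete, here is a minimal standalone PyTorch sketch (not part of this commit; a plain `nn.Parameter` and a matmul stand in for an embedding table and its lookup). It reproduces the multi-forward error and shows why writing through `.data` avoids it:

```python
import torch

# Stand-in for an embedding table weight; `x @ w` stands in for the lookup.
w = torch.nn.Parameter(torch.randn(4, 2))
x = torch.randn(3, 4)

# Two forward passes are issued before any backward, as in multi-forward training.
out1 = (x @ w).sum()  # matmul saves `w` for its backward
out2 = (x @ w).sum()

# Reinitializing rows of `w` directly (even under torch.no_grad()) bumps its
# autograd version counter, so the pending backward would raise:
#   RuntimeError: one of the variables needed for gradient computation has been
#   modified by an inplace operation
# with torch.no_grad():
#     w[torch.tensor([0])] = torch.zeros(1, 2)

# Writing through .data skips the version-counter check, so "evicted" rows can be
# reinitialized while the pending backwards still run. The trade-off is that the
# values saved for backward are silently overwritten for those rows, but the
# autograd graph itself stays usable.
w.data[torch.tensor([0])] = torch.zeros(1, 2)

out1.backward()
out2.backward()
```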

File tree

4 files changed: +503 -114 lines changed

torchrec/distributed/mc_embedding_modules.py

Lines changed: 18 additions & 5 deletions
```diff
@@ -129,6 +129,9 @@ def __init__(
             )
         )
         self._return_remapped_features: bool = module._return_remapped_features
+        self._allow_in_place_embed_weight_update: bool = (
+            module._allow_in_place_embed_weight_update
+        )
 
         # pyre-ignore
         self._table_to_tbe_and_index = {}
@@ -202,12 +205,22 @@ def _evict(self, evictions_per_table: Dict[str, Optional[torch.Tensor]]) -> None
                 init_fn = self._embedding_module._table_name_to_config[
                     table
                 ].init_fn
-
                 # Set evicted indices to original init_fn instead of all zeros
-                # pyre-ignore [29]
-                table_weight_param[evictions_indices_for_table] = init_fn(
-                    table_weight_param[evictions_indices_for_table]
-                )
+                if self._allow_in_place_embed_weight_update:
+                    # In-place update with .data to bypass PyTorch's autograd tracking.
+                    # This is required for model training with multiple forward passes where the autograd graph
+                    # is already created. Direct tensor modification would trigger PyTorch's in-place operation
+                    # checks and invalidate gradients, while .data allows safe reinitialization of evicted
+                    # embeddings without affecting the computational graph.
+                    # pyre-ignore [29]
+                    table_weight_param.data[evictions_indices_for_table] = init_fn(
+                        table_weight_param[evictions_indices_for_table]
+                    )
+                else:
+                    # pyre-ignore [29]
+                    table_weight_param[evictions_indices_for_table] = init_fn(
+                        table_weight_param[evictions_indices_for_table]
+                    )
 
     def compute(
         self,
```

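The code comment in the new branch above refers to autograd's per-tensor version counter. The short sketch below (independent of TorchRec; `_version` is an internal PyTorch attribute, used here only for inspection) shows how a write through `.data` leaves the parameter's own counter untouched, while a direct in-place write bumps it:

```python
import torch

w = torch.nn.Parameter(torch.randn(4, 2))
print(w._version)  # 0

# Writing through .data goes via a shallow copy with its own version counter,
# so `w`'s counter is untouched and saved-tensor checks keep passing.
w.data[1] = torch.zeros(2)
print(w._version)  # still 0

# A direct in-place write (done under no_grad so it is permitted on a leaf
# parameter) bumps `w`'s counter; any backward that saved `w` would then fail
# with the "modified by an inplace operation" error.
with torch.no_grad():
    w[1] = torch.zeros(2)
print(w._version)  # 1
```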
torchrec/distributed/tests/test_mc_embedding.py

Lines changed: 265 additions & 0 deletions
```diff
@@ -59,6 +59,7 @@ def __init__(
         device: torch.device,
         return_remapped: bool = False,
         input_hash_size: int = 4000,
+        allow_in_place_embed_weight_update: bool = False,
     ) -> None:
         super().__init__()
         self._return_remapped = return_remapped
@@ -91,6 +92,7 @@ def __init__(
                     embedding_configs=tables,
                 ),
                 return_remapped_features=self._return_remapped,
+                allow_in_place_embed_weight_update=allow_in_place_embed_weight_update,
             )
         )
 
@@ -242,6 +244,106 @@ def _test_sharding_and_remapping( # noqa C901
         # TODO: validate embedding rows, and eviction
 
 
+def _test_in_place_embd_weight_update( # noqa C901
+    output_keys: List[str],
+    tables: List[EmbeddingConfig],
+    rank: int,
+    world_size: int,
+    kjt_input_per_rank: List[KeyedJaggedTensor],
+    kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]],
+    initial_state_per_rank: List[Dict[str, torch.Tensor]],
+    final_state_per_rank: List[Dict[str, torch.Tensor]],
+    sharder: ModuleSharder[nn.Module],
+    backend: str,
+    local_size: Optional[int] = None,
+    input_hash_size: int = 4000,
+    allow_in_place_embed_weight_update: bool = True,
+) -> None:
+
+    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        kjt_input = kjt_input_per_rank[rank].to(ctx.device)
+        kjt_out_per_iter = [
+            kjt[rank].to(ctx.device) for kjt in kjt_out_per_iter_per_rank
+        ]
+        return_remapped: bool = True
+        sparse_arch = SparseArch(
+            tables,
+            torch.device("meta"),
+            return_remapped=return_remapped,
+            input_hash_size=input_hash_size,
+            allow_in_place_embed_weight_update=allow_in_place_embed_weight_update,
+        )
+        apply_optimizer_in_backward(
+            RowWiseAdagrad,
+            [
+                sparse_arch._mc_ec._embedding_collection.embeddings["table_0"].weight,
+                sparse_arch._mc_ec._embedding_collection.embeddings["table_1"].weight,
+            ],
+            {"lr": 0.01},
+        )
+        module_sharding_plan = construct_module_sharding_plan(
+            sparse_arch._mc_ec,
+            per_param_sharding={"table_0": row_wise(), "table_1": row_wise()},
+            local_size=local_size,
+            world_size=world_size,
+            device_type="cuda" if torch.cuda.is_available() else "cpu",
+            sharder=sharder,
+        )
+
+        sharded_sparse_arch = _shard_modules(
+            module=copy.deepcopy(sparse_arch),
+            plan=ShardingPlan({"_mc_ec": module_sharding_plan}),
+            # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got
+            #  `Optional[ProcessGroup]`.
+            env=ShardingEnv.from_process_group(ctx.pg),
+            sharders=[sharder],
+            device=ctx.device,
+        )
+
+        initial_state_dict = sharded_sparse_arch.state_dict()
+        for key, sharded_tensor in initial_state_dict.items():
+            postfix = ".".join(key.split(".")[-2:])
+            if postfix in initial_state_per_rank[ctx.rank]:
+                tensor = sharded_tensor.local_shards()[0].tensor.cpu()
+                assert torch.equal(
+                    tensor, initial_state_per_rank[ctx.rank][postfix]
+                ), f"initial state {key} on {ctx.rank} does not match, got {tensor}, expect {initial_state_per_rank[rank][postfix]}"
+
+        sharded_sparse_arch.load_state_dict(initial_state_dict)
+
+        # sharded model
+        # each rank gets a subbatch
+        loss1, remapped_ids1 = sharded_sparse_arch(kjt_input)
+        loss2, remapped_ids2 = sharded_sparse_arch(kjt_input)
+
+        if not allow_in_place_embed_weight_update:
+            # Without in-place overwrite the backward pass will fail due to tensor version mismatch
+            with unittest.TestCase().assertRaisesRegex(
+                RuntimeError,
+                "one of the variables needed for gradient computation has been modified by an inplace operation",
+            ):
+                loss1.backward()
+        else:
+            loss1.backward()
+            loss2.backward()
+            final_state_dict = sharded_sparse_arch.state_dict()
+            for key, sharded_tensor in final_state_dict.items():
+                postfix = ".".join(key.split(".")[-2:])
+                if postfix in final_state_per_rank[ctx.rank]:
+                    tensor = sharded_tensor.local_shards()[0].tensor.cpu()
+                    assert torch.equal(
+                        tensor, final_state_per_rank[ctx.rank][postfix]
+                    ), f"initial state {key} on {ctx.rank} does not match, got {tensor}, expect {final_state_per_rank[rank][postfix]}"
+
+            remapped_ids = [remapped_ids1, remapped_ids2]
+            for key in output_keys:
+                for i, kjt_out in enumerate(kjt_out_per_iter):
+                    assert torch.equal(
+                        remapped_ids[i][key].values(),
+                        kjt_out[key].values(),
+                    ), f"feature {key} on {ctx.rank} iteration {i} does not match, got {remapped_ids[i][key].values()}, expect {kjt_out[key].values()}"
+
+
 def _test_sharding_and_resharding( # noqa C901
     tables: List[EmbeddingConfig],
     rank: int,
@@ -1016,3 +1118,166 @@ def test_sharding_zch_mc_ec_dedup_input_error(self, backend: str) -> None:
             ),
         except AssertionError as e:
             self.assertTrue("0 != 1" in str(e))
+
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
+    # pyre-ignore
+    @given(
+        backend=st.sampled_from(["nccl"]),
+        allow_in_place_embed_weight_update=st.booleans(),
+    )
+    @settings(deadline=None)
+    def test_in_place_embd_weight_update(
+        self, backend: str, allow_in_place_embed_weight_update: bool
+    ) -> None:
+
+        WORLD_SIZE = 2
+
+        embedding_config = [
+            EmbeddingConfig(
+                name="table_0",
+                feature_names=["feature_0"],
+                embedding_dim=8,
+                num_embeddings=16,
+            ),
+            EmbeddingConfig(
+                name="table_1",
+                feature_names=["feature_1"],
+                embedding_dim=8,
+                num_embeddings=32,
+            ),
+        ]
+
+        kjt_input_per_rank = [  # noqa
+            KeyedJaggedTensor.from_lengths_sync(
+                keys=["feature_0", "feature_1", "feature_2"],
+                values=torch.LongTensor(
+                    [1000, 2000, 1001, 2000, 2001, 2002, 1, 1, 1],
+                ),
+                lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
+                weights=None,
+            ),
+            KeyedJaggedTensor.from_lengths_sync(
+                keys=["feature_0", "feature_1", "feature_2"],
+                values=torch.LongTensor(
+                    [
+                        1000,
+                        1002,
+                        1004,
+                        2000,
+                        2002,
+                        2004,
+                        2,
+                        2,
+                        2,
+                    ],
+                ),
+                lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
+                weights=None,
+            ),
+        ]
+
+        kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]] = []
+        kjt_out_per_iter_per_rank.append(
+            [
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [7, 15, 7, 31, 31, 31],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [7, 7, 7, 31, 31, 31],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+            ]
+        )
+        # TODO: cleanup sorting so more dedugable/logical initial fill
+
+        kjt_out_per_iter_per_rank.append(
+            [
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [3, 14, 4, 27, 29, 28],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [3, 5, 6, 27, 28, 30],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+            ]
+        )
+
+        initial_state_per_rank = [
+            {
+                "table_0._mch_remapped_ids_mapping": torch.arange(8, dtype=torch.int64),
+                "table_1._mch_remapped_ids_mapping": torch.arange(
+                    16, dtype=torch.int64
+                ),
+            },
+            {
+                "table_0._mch_remapped_ids_mapping": torch.arange(
+                    start=8, end=16, dtype=torch.int64
+                ),
+                "table_1._mch_remapped_ids_mapping": torch.arange(
+                    start=16, end=32, dtype=torch.int64
+                ),
+            },
+        ]
+        max_int = torch.iinfo(torch.int64).max
+
+        final_state_per_rank = [
+            {
+                "table_0._mch_sorted_raw_ids": torch.LongTensor(
+                    [1000, 1001, 1002, 1004] + [max_int] * 4
+                ),
+                "table_1._mch_sorted_raw_ids": torch.LongTensor([max_int] * 16),
+                "table_0._mch_remapped_ids_mapping": torch.LongTensor(
+                    [3, 4, 5, 6, 0, 1, 2, 7]
+                ),
+                "table_1._mch_remapped_ids_mapping": torch.arange(
+                    16, dtype=torch.int64
+                ),
+            },
+            {
+                "table_0._mch_sorted_raw_ids": torch.LongTensor([2000] + [max_int] * 7),
+                "table_1._mch_sorted_raw_ids": torch.LongTensor(
+                    [2000, 2001, 2002, 2004] + [max_int] * 12
+                ),
+                "table_0._mch_remapped_ids_mapping": torch.LongTensor(
+                    [14, 8, 9, 10, 11, 12, 13, 15]
+                ),
+                "table_1._mch_remapped_ids_mapping": torch.LongTensor(
+                    [27, 29, 28, 30, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 31]
+                ),
+            },
+        ]
+
+        self._run_multi_process_test(
+            callable=_test_in_place_embd_weight_update,
+            output_keys=["feature_0", "feature_1"],
+            world_size=WORLD_SIZE,
+            tables=embedding_config,
+            kjt_input_per_rank=kjt_input_per_rank,
+            kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank,
+            initial_state_per_rank=initial_state_per_rank,
+            final_state_per_rank=final_state_per_rank,
+            sharder=ManagedCollisionEmbeddingCollectionSharder(),
+            backend=backend,
+            allow_in_place_embed_weight_update=allow_in_place_embed_weight_update,
+        )
```
