
Commit de77690

aliafzal authored and facebook-github-bot committed
Multi forward MCH eviction fix (#2836)
Summary: ## Issue: Direct tensor modification during training with multiple forward passes breaks PyTorch's autograd graph, causing "one of the variables needed for gradient computation has been modified by an inplace operation" runtime error. ## Solution: Use in-place updates with .data accessor to safely reinitialize evicted embeddings without invalidating gradient computation. Differential Revision: D71491003
1 parent 7652c5d commit de77690

File tree

4 files changed: +486 −111 lines changed


torchrec/distributed/mc_embedding_modules.py

Lines changed: 18 additions & 4 deletions
```diff
@@ -129,6 +129,9 @@ def __init__(
             )
         )
         self._return_remapped_features: bool = module._return_remapped_features
+        self._enable_in_place_data_overwrite: bool = (
+            module._enable_in_place_data_overwrite
+        )
 
         # pyre-ignore
         self._table_to_tbe_and_index = {}
@@ -204,10 +207,21 @@ def _evict(self, evictions_per_table: Dict[str, Optional[torch.Tensor]]) -> None
                     ].init_fn
 
                     # Set evicted indices to original init_fn instead of all zeros
-                    # pyre-ignore [29]
-                    table_weight_param[evictions_indices_for_table] = init_fn(
-                        table_weight_param[evictions_indices_for_table]
-                    )
+                    if self._enable_in_place_data_overwrite:
+                        # In-place update with .data to bypass PyTorch's autograd tracking.
+                        # This is required for model training with multiple forward passes where the autograd graph
+                        # is already created. Direct tensor modification would trigger PyTorch's in-place operation
+                        # checks and invalidate gradients, while .data allows safe reinitialization of evicted
+                        # embeddings without affecting the computational graph.
+                        # pyre-ignore [29]
+                        table_weight_param.data[evictions_indices_for_table] = init_fn(
+                            table_weight_param[evictions_indices_for_table]
+                        )
+                    else:
+                        # pyre-ignore [29]
+                        table_weight_param[evictions_indices_for_table] = init_fn(
+                            table_weight_param[evictions_indices_for_table]
+                        )
 
     def compute(
         self,
```
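For reference, the following is a simplified, self-contained rendition of the flag-guarded reinitialization pattern applied in the hunk above. The helper `reinit_evicted_rows` and its signature are illustrative only, not a torchrec API; in the sharded module the weight is the TBE table parameter and `init_fn` comes from the table's config.

```python
from typing import Callable

import torch


def reinit_evicted_rows(
    weight: torch.nn.Parameter,
    evicted_rows: torch.Tensor,  # 1-D LongTensor of row indices to reset
    init_fn: Callable[[torch.Tensor], torch.Tensor],
    enable_in_place_data_overwrite: bool,
) -> None:
    """Hypothetical analogue of the _evict branch added in the hunk above."""
    with torch.no_grad():
        if enable_in_place_data_overwrite:
            # Write through .data so the parameter's version counter is not bumped
            # and autograd graphs from earlier forward passes stay valid.
            weight.data[evicted_rows] = init_fn(weight[evicted_rows])
        else:
            # Direct write: fine for a single forward pass, but it invalidates any
            # pending backward that saved `weight`.
            weight[evicted_rows] = init_fn(weight[evicted_rows])


# Example usage with a random re-init function.
table = torch.nn.Parameter(torch.randn(16, 8))
reinit_evicted_rows(table, torch.tensor([2, 5]), torch.randn_like, True)
```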

torchrec/distributed/tests/test_mc_embedding.py

Lines changed: 262 additions & 0 deletions
```diff
@@ -59,6 +59,7 @@ def __init__(
         device: torch.device,
         return_remapped: bool = False,
         input_hash_size: int = 4000,
+        enable_in_place_data_overwrite: bool = False,
     ) -> None:
         super().__init__()
         self._return_remapped = return_remapped
@@ -91,6 +92,7 @@ def __init__(
                     embedding_configs=tables,
                 ),
                 return_remapped_features=self._return_remapped,
+                enable_in_place_data_overwrite=enable_in_place_data_overwrite,
             )
         )
 
@@ -242,6 +244,104 @@ def _test_sharding_and_remapping( # noqa C901
     # TODO: validate embedding rows, and eviction
 
 
+def _test_in_place_data_overwrite(  # noqa C901
+    output_keys: List[str],
+    tables: List[EmbeddingConfig],
+    rank: int,
+    world_size: int,
+    kjt_input_per_rank: List[KeyedJaggedTensor],
+    kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]],
+    initial_state_per_rank: List[Dict[str, torch.Tensor]],
+    final_state_per_rank: List[Dict[str, torch.Tensor]],
+    sharder: ModuleSharder[nn.Module],
+    backend: str,
+    local_size: Optional[int] = None,
+    input_hash_size: int = 4000,
+    enable_in_place_data_overwrite: bool = True,
+) -> None:
+
+    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        kjt_input = kjt_input_per_rank[rank].to(ctx.device)
+        kjt_out_per_iter = [
+            kjt[rank].to(ctx.device) for kjt in kjt_out_per_iter_per_rank
+        ]
+        return_remapped: bool = True
+        sparse_arch = SparseArch(
+            tables,
+            torch.device("meta"),
+            return_remapped=return_remapped,
+            input_hash_size=input_hash_size,
+        )
+
+        apply_optimizer_in_backward(
+            RowWiseAdagrad,
+            [
+                sparse_arch._mc_ec._embedding_collection.embeddings["table_0"].weight,
+                sparse_arch._mc_ec._embedding_collection.embeddings["table_1"].weight,
+            ],
+            {"lr": 0.01},
+        )
+        module_sharding_plan = construct_module_sharding_plan(
+            sparse_arch._mc_ec,
+            per_param_sharding={"table_0": row_wise(), "table_1": row_wise()},
+            local_size=local_size,
+            world_size=world_size,
+            device_type="cuda" if torch.cuda.is_available() else "cpu",
+            sharder=sharder,
+        )
+
+        sharded_sparse_arch = _shard_modules(
+            module=copy.deepcopy(sparse_arch),
+            plan=ShardingPlan({"_mc_ec": module_sharding_plan}),
+            # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got
+            #  `Optional[ProcessGroup]`.
+            env=ShardingEnv.from_process_group(ctx.pg),
+            sharders=[sharder],
+            device=ctx.device,
+        )
+
+        initial_state_dict = sharded_sparse_arch.state_dict()
+        for key, sharded_tensor in initial_state_dict.items():
+            postfix = ".".join(key.split(".")[-2:])
+            if postfix in initial_state_per_rank[ctx.rank]:
+                tensor = sharded_tensor.local_shards()[0].tensor.cpu()
+                assert torch.equal(
+                    tensor, initial_state_per_rank[ctx.rank][postfix]
+                ), f"initial state {key} on {ctx.rank} does not match, got {tensor}, expect {initial_state_per_rank[rank][postfix]}"
+
+        sharded_sparse_arch.load_state_dict(initial_state_dict)
+
+        # sharded model
+        # each rank gets a subbatch
+        loss1, remapped_ids1 = sharded_sparse_arch(kjt_input)
+        loss2, remapped_ids2 = sharded_sparse_arch(kjt_input)
+        if not enable_in_place_data_overwrite:
+            # Without in-place overwrite the backward pass will fail due to tensor version mismatch
+            with unittest.TestCase().assertRaisesRegex(
+                RuntimeError,
+                "one of the variables needed for gradient computation has been modified by an inplace operation",
+            ):
+                loss1.backward()
+                loss2.backward()
+
+        final_state_dict = sharded_sparse_arch.state_dict()
+        for key, sharded_tensor in final_state_dict.items():
+            postfix = ".".join(key.split(".")[-2:])
+            if postfix in final_state_per_rank[ctx.rank]:
+                tensor = sharded_tensor.local_shards()[0].tensor.cpu()
+                assert torch.equal(
+                    tensor, final_state_per_rank[ctx.rank][postfix]
+                ), f"final state {key} on {ctx.rank} does not match, got {tensor}, expect {final_state_per_rank[rank][postfix]}"
+
+        remapped_ids = [remapped_ids1, remapped_ids2]
+        for key in output_keys:
+            for i, kjt_out in enumerate(kjt_out_per_iter):
+                assert torch.equal(
+                    remapped_ids[i][key].values(),
+                    kjt_out[key].values(),
+                ), f"feature {key} on {ctx.rank} iteration {i} does not match, got {remapped_ids[i][key].values()}, expect {kjt_out[key].values()}"
+
+
 def _test_sharding_and_resharding( # noqa C901
     tables: List[EmbeddingConfig],
     rank: int,
@@ -1016,3 +1116,165 @@ def test_sharding_zch_mc_ec_dedup_input_error(self, backend: str) -> None:
                 ),
         except AssertionError as e:
             self.assertTrue("0 != 1" in str(e))
+
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
+    # pyre-ignore
+    @given(
+        backend=st.sampled_from(["nccl"]), enable_in_place_data_overwrite=st.booleans()
+    )
+    @settings(deadline=None)
+    def test_in_place_data_overwrite(
+        self, backend: str, enable_in_place_data_overwrite: bool
+    ) -> None:
+
+        WORLD_SIZE = 2
+
+        embedding_config = [
+            EmbeddingConfig(
+                name="table_0",
+                feature_names=["feature_0"],
+                embedding_dim=8,
+                num_embeddings=16,
+            ),
+            EmbeddingConfig(
+                name="table_1",
+                feature_names=["feature_1"],
+                embedding_dim=8,
+                num_embeddings=32,
+            ),
+        ]
+
+        kjt_input_per_rank = [  # noqa
+            KeyedJaggedTensor.from_lengths_sync(
+                keys=["feature_0", "feature_1", "feature_2"],
+                values=torch.LongTensor(
+                    [1000, 2000, 1001, 2000, 2001, 2002, 1, 1, 1],
+                ),
+                lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
+                weights=None,
+            ),
+            KeyedJaggedTensor.from_lengths_sync(
+                keys=["feature_0", "feature_1", "feature_2"],
+                values=torch.LongTensor(
+                    [
+                        1000,
+                        1002,
+                        1004,
+                        2000,
+                        2002,
+                        2004,
+                        2,
+                        2,
+                        2,
+                    ],
+                ),
+                lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
+                weights=None,
+            ),
+        ]
+
+        kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]] = []
+        kjt_out_per_iter_per_rank.append(
+            [
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [7, 15, 7, 31, 31, 31],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [7, 7, 7, 31, 31, 31],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+            ]
+        )
+        # TODO: cleanup sorting so more debuggable/logical initial fill
+
+        kjt_out_per_iter_per_rank.append(
+            [
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [3, 14, 4, 27, 29, 28],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [3, 5, 6, 27, 28, 30],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+            ]
+        )
+
+        initial_state_per_rank = [
+            {
+                "table_0._mch_remapped_ids_mapping": torch.arange(8, dtype=torch.int64),
+                "table_1._mch_remapped_ids_mapping": torch.arange(
+                    16, dtype=torch.int64
+                ),
+            },
+            {
+                "table_0._mch_remapped_ids_mapping": torch.arange(
+                    start=8, end=16, dtype=torch.int64
+                ),
+                "table_1._mch_remapped_ids_mapping": torch.arange(
+                    start=16, end=32, dtype=torch.int64
+                ),
+            },
+        ]
+        max_int = torch.iinfo(torch.int64).max
+
+        final_state_per_rank = [
+            {
+                "table_0._mch_sorted_raw_ids": torch.LongTensor(
+                    [1000, 1001, 1002, 1004] + [max_int] * 4
+                ),
+                "table_1._mch_sorted_raw_ids": torch.LongTensor([max_int] * 16),
+                "table_0._mch_remapped_ids_mapping": torch.LongTensor(
+                    [3, 4, 5, 6, 0, 1, 2, 7]
+                ),
+                "table_1._mch_remapped_ids_mapping": torch.arange(
+                    16, dtype=torch.int64
+                ),
+            },
+            {
+                "table_0._mch_sorted_raw_ids": torch.LongTensor([2000] + [max_int] * 7),
+                "table_1._mch_sorted_raw_ids": torch.LongTensor(
+                    [2000, 2001, 2002, 2004] + [max_int] * 12
+                ),
+                "table_0._mch_remapped_ids_mapping": torch.LongTensor(
+                    [14, 8, 9, 10, 11, 12, 13, 15]
+                ),
+                "table_1._mch_remapped_ids_mapping": torch.LongTensor(
+                    [27, 29, 28, 30, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 31]
+                ),
+            },
+        ]
+
+        self._run_multi_process_test(
+            callable=_test_in_place_data_overwrite,
+            output_keys=["feature_0", "feature_1"],
+            world_size=WORLD_SIZE,
+            tables=embedding_config,
+            kjt_input_per_rank=kjt_input_per_rank,
+            kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank,
+            initial_state_per_rank=initial_state_per_rank,
+            final_state_per_rank=final_state_per_rank,
+            sharder=ManagedCollisionEmbeddingCollectionSharder(),
+            backend=backend,
+            enable_in_place_data_overwrite=enable_in_place_data_overwrite,
+        )
```
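For readers who want the assertion pattern without the multi-process sharding setup, below is a hypothetical, much smaller analogue of the test above: the same hypothesis parametrization over the flag and the same expected RuntimeError message, with a plain parameter standing in for the sharded embedding table. It is a sketch of the pattern, not the torchrec test.

```python
import unittest

import torch
from hypothesis import given, settings
from hypothesis import strategies as st


class InPlaceDataOverwriteSketch(unittest.TestCase):
    # pyre-ignore
    @given(enable_in_place_data_overwrite=st.booleans())
    @settings(deadline=None)
    def test_two_forwards_then_evict(
        self, enable_in_place_data_overwrite: bool
    ) -> None:
        w = torch.nn.Parameter(torch.randn(4, 4))
        loss1 = (w ** 2).sum()  # first forward pass
        loss2 = (w ** 2).sum()  # second forward pass

        with torch.no_grad():
            if enable_in_place_data_overwrite:
                # .data write leaves the version counter alone.
                w.data[0] = torch.zeros(4)
            else:
                # Direct write bumps the version counter.
                w[0] = torch.zeros(4)

        if enable_in_place_data_overwrite:
            loss1.backward()
            loss2.backward()
        else:
            with self.assertRaisesRegex(
                RuntimeError,
                "one of the variables needed for gradient computation "
                "has been modified by an inplace operation",
            ):
                loss1.backward()


if __name__ == "__main__":
    unittest.main()
```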
