@@ -59,6 +59,7 @@ def __init__(
         device: torch.device,
         return_remapped: bool = False,
         input_hash_size: int = 4000,
+        enable_in_place_data_overwrite: bool = False,
     ) -> None:
         super().__init__()
         self._return_remapped = return_remapped
@@ -91,6 +92,7 @@ def __init__(
                     embedding_configs=tables,
                 ),
                 return_remapped_features=self._return_remapped,
+                enable_in_place_data_overwrite=enable_in_place_data_overwrite,
             )
         )
@@ -242,6 +244,104 @@ def _test_sharding_and_remapping(  # noqa C901
     # TODO: validate embedding rows, and eviction


+
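+# Shards a SparseArch over world_size ranks, runs two training iterations on the
+# same input, and validates the remapped ids plus the managed-collision state dict
+# before and after training.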
+def _test_in_place_data_overwrite(  # noqa C901
+    output_keys: List[str],
+    tables: List[EmbeddingConfig],
+    rank: int,
+    world_size: int,
+    kjt_input_per_rank: List[KeyedJaggedTensor],
+    kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]],
+    initial_state_per_rank: List[Dict[str, torch.Tensor]],
+    final_state_per_rank: List[Dict[str, torch.Tensor]],
+    sharder: ModuleSharder[nn.Module],
+    backend: str,
+    local_size: Optional[int] = None,
+    input_hash_size: int = 4000,
+    enable_in_place_data_overwrite: bool = True,
+) -> None:
+
+    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        kjt_input = kjt_input_per_rank[rank].to(ctx.device)
+        kjt_out_per_iter = [
+            kjt[rank].to(ctx.device) for kjt in kjt_out_per_iter_per_rank
+        ]
+        return_remapped: bool = True
+        sparse_arch = SparseArch(
+            tables,
+            torch.device("meta"),
+            return_remapped=return_remapped,
+            input_hash_size=input_hash_size,
+            # forward the flag so both overwrite modes are actually exercised
+            enable_in_place_data_overwrite=enable_in_place_data_overwrite,
+        )
+
+        apply_optimizer_in_backward(
+            RowWiseAdagrad,
+            [
+                sparse_arch._mc_ec._embedding_collection.embeddings["table_0"].weight,
+                sparse_arch._mc_ec._embedding_collection.embeddings["table_1"].weight,
+            ],
+            {"lr": 0.01},
+        )
+        module_sharding_plan = construct_module_sharding_plan(
+            sparse_arch._mc_ec,
+            per_param_sharding={"table_0": row_wise(), "table_1": row_wise()},
+            local_size=local_size,
+            world_size=world_size,
+            device_type="cuda" if torch.cuda.is_available() else "cpu",
+            sharder=sharder,
+        )
+
+        sharded_sparse_arch = _shard_modules(
+            module=copy.deepcopy(sparse_arch),
+            plan=ShardingPlan({"_mc_ec": module_sharding_plan}),
+            # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got
+            #  `Optional[ProcessGroup]`.
+            env=ShardingEnv.from_process_group(ctx.pg),
+            sharders=[sharder],
+            device=ctx.device,
+        )
+
+        initial_state_dict = sharded_sparse_arch.state_dict()
+        for key, sharded_tensor in initial_state_dict.items():
+            postfix = ".".join(key.split(".")[-2:])
+            if postfix in initial_state_per_rank[ctx.rank]:
+                tensor = sharded_tensor.local_shards()[0].tensor.cpu()
+                assert torch.equal(
+                    tensor, initial_state_per_rank[ctx.rank][postfix]
+                ), f"initial state {key} on {ctx.rank} does not match, got {tensor}, expect {initial_state_per_rank[rank][postfix]}"
+
+        sharded_sparse_arch.load_state_dict(initial_state_dict)
+
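+        # Two forward passes run back to back so the second pass rewrites managed-
+        # collision state (and evicted embedding rows) that the first pass's autograd
+        # graph still references; enable_in_place_data_overwrite is expected to make
+        # that overwrite safe for backward.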
+        # sharded model
+        # each rank gets a subbatch
+        loss1, remapped_ids1 = sharded_sparse_arch(kjt_input)
+        loss2, remapped_ids2 = sharded_sparse_arch(kjt_input)
+        if not enable_in_place_data_overwrite:
+            # Without in-place overwrite the backward pass will fail due to tensor version mismatch
+            with unittest.TestCase().assertRaisesRegex(
+                RuntimeError,
+                "one of the variables needed for gradient computation has been modified by an inplace operation",
+            ):
+                loss1.backward()
+        loss2.backward()
+
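+        # Regardless of the overwrite mode, the trained state and the per-iteration
+        # remapped outputs must match the expected per-rank values.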
+        final_state_dict = sharded_sparse_arch.state_dict()
+        for key, sharded_tensor in final_state_dict.items():
+            postfix = ".".join(key.split(".")[-2:])
+            if postfix in final_state_per_rank[ctx.rank]:
+                tensor = sharded_tensor.local_shards()[0].tensor.cpu()
+                assert torch.equal(
+                    tensor, final_state_per_rank[ctx.rank][postfix]
+                ), f"final state {key} on {ctx.rank} does not match, got {tensor}, expect {final_state_per_rank[rank][postfix]}"
+
+        remapped_ids = [remapped_ids1, remapped_ids2]
+        for key in output_keys:
+            for i, kjt_out in enumerate(kjt_out_per_iter):
+                assert torch.equal(
+                    remapped_ids[i][key].values(),
+                    kjt_out[key].values(),
+                ), f"feature {key} on {ctx.rank} iteration {i} does not match, got {remapped_ids[i][key].values()}, expect {kjt_out[key].values()}"
+
+
 def _test_sharding_and_resharding(  # noqa C901
     tables: List[EmbeddingConfig],
     rank: int,
@@ -1016,3 +1116,165 @@ def test_sharding_zch_mc_ec_dedup_input_error(self, backend: str) -> None:
                 ),
         except AssertionError as e:
             self.assertTrue("0 != 1" in str(e))
+
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
+    # pyre-ignore
+    @given(
+        backend=st.sampled_from(["nccl"]), enable_in_place_data_overwrite=st.booleans()
+    )
+    @settings(deadline=None)
+    def test_in_place_data_overwrite(
+        self, backend: str, enable_in_place_data_overwrite: bool
+    ) -> None:
+
+        WORLD_SIZE = 2
+
+        embedding_config = [
+            EmbeddingConfig(
+                name="table_0",
+                feature_names=["feature_0"],
+                embedding_dim=8,
+                num_embeddings=16,
+            ),
+            EmbeddingConfig(
+                name="table_1",
+                feature_names=["feature_1"],
+                embedding_dim=8,
+                num_embeddings=32,
+            ),
+        ]
+
+        kjt_input_per_rank = [  # noqa
+            KeyedJaggedTensor.from_lengths_sync(
+                keys=["feature_0", "feature_1", "feature_2"],
+                values=torch.LongTensor(
+                    [1000, 2000, 1001, 2000, 2001, 2002, 1, 1, 1],
+                ),
+                lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
+                weights=None,
+            ),
+            KeyedJaggedTensor.from_lengths_sync(
+                keys=["feature_0", "feature_1", "feature_2"],
+                values=torch.LongTensor(
+                    [
+                        1000,
+                        1002,
+                        1004,
+                        2000,
+                        2002,
+                        2004,
+                        2,
+                        2,
+                        2,
+                    ],
+                ),
+                lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
+                weights=None,
+            ),
+        ]
+
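+        # feature_2 is not assigned to either table, so only feature_0 / feature_1
+        # appear in the expected remapped outputs below.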
+        kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]] = []
+        kjt_out_per_iter_per_rank.append(
+            [
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [7, 15, 7, 31, 31, 31],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [7, 7, 7, 31, 31, 31],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+            ]
+        )
+        # TODO: clean up sorting so the initial fill is more debuggable/logical
+
+        kjt_out_per_iter_per_rank.append(
+            [
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [3, 14, 4, 27, 29, 28],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [3, 5, 6, 27, 28, 30],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+            ]
+        )
+
+        initial_state_per_rank = [
+            {
+                "table_0._mch_remapped_ids_mapping": torch.arange(8, dtype=torch.int64),
+                "table_1._mch_remapped_ids_mapping": torch.arange(
+                    16, dtype=torch.int64
+                ),
+            },
+            {
+                "table_0._mch_remapped_ids_mapping": torch.arange(
+                    start=8, end=16, dtype=torch.int64
+                ),
+                "table_1._mch_remapped_ids_mapping": torch.arange(
+                    start=16, end=32, dtype=torch.int64
+                ),
+            },
+        ]
+        max_int = torch.iinfo(torch.int64).max
+
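+        # max_int marks empty slots in the expected _mch_sorted_raw_ids buffers below.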
+        final_state_per_rank = [
+            {
+                "table_0._mch_sorted_raw_ids": torch.LongTensor(
+                    [1000, 1001, 1002, 1004] + [max_int] * 4
+                ),
+                "table_1._mch_sorted_raw_ids": torch.LongTensor([max_int] * 16),
+                "table_0._mch_remapped_ids_mapping": torch.LongTensor(
+                    [3, 4, 5, 6, 0, 1, 2, 7]
+                ),
+                "table_1._mch_remapped_ids_mapping": torch.arange(
+                    16, dtype=torch.int64
+                ),
+            },
+            {
+                "table_0._mch_sorted_raw_ids": torch.LongTensor([2000] + [max_int] * 7),
+                "table_1._mch_sorted_raw_ids": torch.LongTensor(
+                    [2000, 2001, 2002, 2004] + [max_int] * 12
+                ),
+                "table_0._mch_remapped_ids_mapping": torch.LongTensor(
+                    [14, 8, 9, 10, 11, 12, 13, 15]
+                ),
+                "table_1._mch_remapped_ids_mapping": torch.LongTensor(
+                    [27, 29, 28, 30, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 31]
+                ),
+            },
+        ]
+
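+        # One process per rank; extra kwargs, including the hypothesis-driven
+        # enable_in_place_data_overwrite flag, are forwarded to the test callable.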
+        self._run_multi_process_test(
+            callable=_test_in_place_data_overwrite,
+            output_keys=["feature_0", "feature_1"],
+            world_size=WORLD_SIZE,
+            tables=embedding_config,
+            kjt_input_per_rank=kjt_input_per_rank,
+            kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank,
+            initial_state_per_rank=initial_state_per_rank,
+            final_state_per_rank=final_state_per_rank,
+            sharder=ManagedCollisionEmbeddingCollectionSharder(),
+            backend=backend,
+            enable_in_place_data_overwrite=enable_in_place_data_overwrite,
+        )