@@ -59,6 +59,7 @@ def __init__(
         device: torch.device,
         return_remapped: bool = False,
         input_hash_size: int = 4000,
+        allow_in_place_embed_weight_update: bool = False,
     ) -> None:
         super().__init__()
         self._return_remapped = return_remapped
@@ -91,6 +92,7 @@ def __init__(
                     embedding_configs=tables,
                 ),
                 return_remapped_features=self._return_remapped,
+                allow_in_place_embed_weight_update=allow_in_place_embed_weight_update,
             )
         )

@@ -242,6 +244,104 @@ def _test_sharding_and_remapping(  # noqa C901
         # TODO: validate embedding rows, and eviction


+def _test_in_place_embd_weight_update(  # noqa C901
+    output_keys: List[str],
+    tables: List[EmbeddingConfig],
+    rank: int,
+    world_size: int,
+    kjt_input_per_rank: List[KeyedJaggedTensor],
+    kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]],
+    initial_state_per_rank: List[Dict[str, torch.Tensor]],
+    final_state_per_rank: List[Dict[str, torch.Tensor]],
+    sharder: ModuleSharder[nn.Module],
+    backend: str,
+    local_size: Optional[int] = None,
+    input_hash_size: int = 4000,
+    allow_in_place_embed_weight_update: bool = True,
+) -> None:
+
+    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        kjt_input = kjt_input_per_rank[rank].to(ctx.device)
+        kjt_out_per_iter = [
+            kjt[rank].to(ctx.device) for kjt in kjt_out_per_iter_per_rank
+        ]
+        return_remapped: bool = True
+        sparse_arch = SparseArch(
+            tables,
+            torch.device("meta"),
+            return_remapped=return_remapped,
+            input_hash_size=input_hash_size,
+        )
+
+        apply_optimizer_in_backward(
+            RowWiseAdagrad,
+            [
+                sparse_arch._mc_ec._embedding_collection.embeddings["table_0"].weight,
+                sparse_arch._mc_ec._embedding_collection.embeddings["table_1"].weight,
+            ],
+            {"lr": 0.01},
+        )
+        module_sharding_plan = construct_module_sharding_plan(
+            sparse_arch._mc_ec,
+            per_param_sharding={"table_0": row_wise(), "table_1": row_wise()},
+            local_size=local_size,
+            world_size=world_size,
+            device_type="cuda" if torch.cuda.is_available() else "cpu",
+            sharder=sharder,
+        )
+
+        sharded_sparse_arch = _shard_modules(
+            module=copy.deepcopy(sparse_arch),
+            plan=ShardingPlan({"_mc_ec": module_sharding_plan}),
+            # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got
+            #  `Optional[ProcessGroup]`.
+            env=ShardingEnv.from_process_group(ctx.pg),
+            sharders=[sharder],
+            device=ctx.device,
+        )
+
+        initial_state_dict = sharded_sparse_arch.state_dict()
+        for key, sharded_tensor in initial_state_dict.items():
+            postfix = ".".join(key.split(".")[-2:])
+            if postfix in initial_state_per_rank[ctx.rank]:
+                tensor = sharded_tensor.local_shards()[0].tensor.cpu()
+                assert torch.equal(
+                    tensor, initial_state_per_rank[ctx.rank][postfix]
+                ), f"initial state {key} on {ctx.rank} does not match, got {tensor}, expect {initial_state_per_rank[rank][postfix]}"
+
+        sharded_sparse_arch.load_state_dict(initial_state_dict)
+
+        # sharded model
+        # each rank gets a subbatch
+        loss1, remapped_ids1 = sharded_sparse_arch(kjt_input)
+        loss2, remapped_ids2 = sharded_sparse_arch(kjt_input)
+        if not allow_in_place_embed_weight_update:
+            # Without in-place overwrite the backward pass will fail due to tensor version mismatch
+            with unittest.TestCase().assertRaisesRegex(
+                RuntimeError,
+                "one of the variables needed for gradient computation has been modified by an inplace operation",
+            ):
+                loss1.backward()
+        loss2.backward()
+
+        final_state_dict = sharded_sparse_arch.state_dict()
+        for key, sharded_tensor in final_state_dict.items():
+            postfix = ".".join(key.split(".")[-2:])
+            if postfix in final_state_per_rank[ctx.rank]:
+                tensor = sharded_tensor.local_shards()[0].tensor.cpu()
+                assert torch.equal(
+                    tensor, final_state_per_rank[ctx.rank][postfix]
+                ), f"final state {key} on {ctx.rank} does not match, got {tensor}, expect {final_state_per_rank[rank][postfix]}"
+
+        remapped_ids = [remapped_ids1, remapped_ids2]
+        for key in output_keys:
+            for i, kjt_out in enumerate(kjt_out_per_iter):
+                assert torch.equal(
+                    remapped_ids[i][key].values(),
+                    kjt_out[key].values(),
+                ), f"feature {key} on {ctx.rank} iteration {i} does not match, got {remapped_ids[i][key].values()}, expect {kjt_out[key].values()}"
+
+
 def _test_sharding_and_resharding(  # noqa C901
     tables: List[EmbeddingConfig],
     rank: int,
@@ -1016,3 +1116,166 @@ def test_sharding_zch_mc_ec_dedup_input_error(self, backend: str) -> None:
             ),
         except AssertionError as e:
             self.assertTrue("0 != 1" in str(e))
+
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
+    # pyre-ignore
+    @given(
+        backend=st.sampled_from(["nccl"]),
+        allow_in_place_embed_weight_update=st.booleans(),
+    )
+    @settings(deadline=None)
+    def test_in_place_embd_weight_update(
+        self, backend: str, allow_in_place_embed_weight_update: bool
+    ) -> None:
+
+        WORLD_SIZE = 2
+
+        embedding_config = [
+            EmbeddingConfig(
+                name="table_0",
+                feature_names=["feature_0"],
+                embedding_dim=8,
+                num_embeddings=16,
+            ),
+            EmbeddingConfig(
+                name="table_1",
+                feature_names=["feature_1"],
+                embedding_dim=8,
+                num_embeddings=32,
+            ),
+        ]
+
+        kjt_input_per_rank = [  # noqa
+            KeyedJaggedTensor.from_lengths_sync(
+                keys=["feature_0", "feature_1", "feature_2"],
+                values=torch.LongTensor(
+                    [1000, 2000, 1001, 2000, 2001, 2002, 1, 1, 1],
+                ),
+                lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
+                weights=None,
+            ),
+            KeyedJaggedTensor.from_lengths_sync(
+                keys=["feature_0", "feature_1", "feature_2"],
+                values=torch.LongTensor(
+                    [
+                        1000,
+                        1002,
+                        1004,
+                        2000,
+                        2002,
+                        2004,
+                        2,
+                        2,
+                        2,
+                    ],
+                ),
+                lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
+                weights=None,
+            ),
+        ]
+
+        kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]] = []
+        kjt_out_per_iter_per_rank.append(
+            [
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [7, 15, 7, 31, 31, 31],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [7, 7, 7, 31, 31, 31],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+            ]
+        )
+        # TODO: cleanup sorting so more debuggable/logical initial fill
+
+        kjt_out_per_iter_per_rank.append(
+            [
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [3, 14, 4, 27, 29, 28],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [3, 5, 6, 27, 28, 30],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+            ]
+        )
+
+        initial_state_per_rank = [
+            {
+                "table_0._mch_remapped_ids_mapping": torch.arange(8, dtype=torch.int64),
+                "table_1._mch_remapped_ids_mapping": torch.arange(
+                    16, dtype=torch.int64
+                ),
+            },
+            {
+                "table_0._mch_remapped_ids_mapping": torch.arange(
+                    start=8, end=16, dtype=torch.int64
+                ),
+                "table_1._mch_remapped_ids_mapping": torch.arange(
+                    start=16, end=32, dtype=torch.int64
+                ),
+            },
+        ]
+        max_int = torch.iinfo(torch.int64).max
+
+        final_state_per_rank = [
+            {
+                "table_0._mch_sorted_raw_ids": torch.LongTensor(
+                    [1000, 1001, 1002, 1004] + [max_int] * 4
+                ),
+                "table_1._mch_sorted_raw_ids": torch.LongTensor([max_int] * 16),
+                "table_0._mch_remapped_ids_mapping": torch.LongTensor(
+                    [3, 4, 5, 6, 0, 1, 2, 7]
+                ),
+                "table_1._mch_remapped_ids_mapping": torch.arange(
+                    16, dtype=torch.int64
+                ),
+            },
+            {
+                "table_0._mch_sorted_raw_ids": torch.LongTensor([2000] + [max_int] * 7),
+                "table_1._mch_sorted_raw_ids": torch.LongTensor(
+                    [2000, 2001, 2002, 2004] + [max_int] * 12
+                ),
+                "table_0._mch_remapped_ids_mapping": torch.LongTensor(
+                    [14, 8, 9, 10, 11, 12, 13, 15]
+                ),
+                "table_1._mch_remapped_ids_mapping": torch.LongTensor(
+                    [27, 29, 28, 30, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 31]
+                ),
+            },
+        ]
+
+        self._run_multi_process_test(
+            callable=_test_in_place_embd_weight_update,
+            output_keys=["feature_0", "feature_1"],
+            world_size=WORLD_SIZE,
+            tables=embedding_config,
+            kjt_input_per_rank=kjt_input_per_rank,
+            kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank,
+            initial_state_per_rank=initial_state_per_rank,
+            final_state_per_rank=final_state_per_rank,
+            sharder=ManagedCollisionEmbeddingCollectionSharder(),
+            backend=backend,
+            allow_in_place_embed_weight_update=allow_in_place_embed_weight_update,
+        )
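
Note (not part of the patch above): the RuntimeError asserted in `_test_in_place_embd_weight_update` is standard PyTorch autograd behavior when a tensor saved for backward is overwritten in place before `backward()` runs, which is what an in-place embedding-weight update between two forward passes does. A minimal, self-contained sketch of that failure mode, using plain PyTorch rather than TorchRec; the tensor names and shapes here are illustrative only:

```python
import torch

# Stand-in for an embedding table that participates in autograd.
weight = torch.randn(8, 4, requires_grad=True)


def forward(w: torch.Tensor) -> torch.Tensor:
    # The multiplication saves `w` for the backward pass.
    return (w * w).sum()


loss1 = forward(weight)  # iteration 1

# Simulate an in-place weight update between iterations,
# e.g. an eviction step rewriting rows of the table.
with torch.no_grad():
    weight[0].copy_(torch.zeros(4))  # bumps the tensor's version counter

loss2 = forward(weight)  # iteration 2 sees the updated table

try:
    loss1.backward()  # stale graph -> version counter mismatch
except RuntimeError as err:
    assert "modified by an inplace operation" in str(err)

loss2.backward()  # the graph built after the update backpropagates fine
```

This mirrors the test: when `allow_in_place_embed_weight_update` is False the error above is expected on `loss1.backward()`, and in either case only `loss2.backward()`, whose graph was built after the update, is expected to succeed.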