@@ -59,6 +59,7 @@ def __init__(
         device: torch.device,
         return_remapped: bool = False,
         input_hash_size: int = 4000,
+        allow_in_place_embed_weight_update: bool = False,
     ) -> None:
         super().__init__()
         self._return_remapped = return_remapped
@@ -91,6 +92,7 @@ def __init__(
                     embedding_configs=tables,
                 ),
                 return_remapped_features=self._return_remapped,
+                allow_in_place_embed_weight_update=allow_in_place_embed_weight_update,
             )
         )

@@ -242,6 +244,106 @@ def _test_sharding_and_remapping(  # noqa C901
         # TODO: validate embedding rows, and eviction


+def _test_in_place_embd_weight_update(  # noqa C901
+    output_keys: List[str],
+    tables: List[EmbeddingConfig],
+    rank: int,
+    world_size: int,
+    kjt_input_per_rank: List[KeyedJaggedTensor],
+    kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]],
+    initial_state_per_rank: List[Dict[str, torch.Tensor]],
+    final_state_per_rank: List[Dict[str, torch.Tensor]],
+    sharder: ModuleSharder[nn.Module],
+    backend: str,
+    local_size: Optional[int] = None,
+    input_hash_size: int = 4000,
+    allow_in_place_embed_weight_update: bool = True,
+) -> None:
+
+    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        kjt_input = kjt_input_per_rank[rank].to(ctx.device)
+        kjt_out_per_iter = [
+            kjt[rank].to(ctx.device) for kjt in kjt_out_per_iter_per_rank
+        ]
+        return_remapped: bool = True
+        sparse_arch = SparseArch(
+            tables,
+            torch.device("meta"),
+            return_remapped=return_remapped,
+            input_hash_size=input_hash_size,
+            allow_in_place_embed_weight_update=allow_in_place_embed_weight_update,
+        )
+        apply_optimizer_in_backward(
+            RowWiseAdagrad,
+            [
+                sparse_arch._mc_ec._embedding_collection.embeddings["table_0"].weight,
+                sparse_arch._mc_ec._embedding_collection.embeddings["table_1"].weight,
+            ],
+            {"lr": 0.01},
+        )
+        module_sharding_plan = construct_module_sharding_plan(
+            sparse_arch._mc_ec,
+            per_param_sharding={"table_0": row_wise(), "table_1": row_wise()},
+            local_size=local_size,
+            world_size=world_size,
+            device_type="cuda" if torch.cuda.is_available() else "cpu",
+            sharder=sharder,
+        )
+
+        sharded_sparse_arch = _shard_modules(
+            module=copy.deepcopy(sparse_arch),
+            plan=ShardingPlan({"_mc_ec": module_sharding_plan}),
+            # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got
+            #  `Optional[ProcessGroup]`.
+            env=ShardingEnv.from_process_group(ctx.pg),
+            sharders=[sharder],
+            device=ctx.device,
+        )
+
+        initial_state_dict = sharded_sparse_arch.state_dict()
+        for key, sharded_tensor in initial_state_dict.items():
+            postfix = ".".join(key.split(".")[-2:])
+            if postfix in initial_state_per_rank[ctx.rank]:
+                tensor = sharded_tensor.local_shards()[0].tensor.cpu()
+                assert torch.equal(
+                    tensor, initial_state_per_rank[ctx.rank][postfix]
+                ), f"initial state {key} on {ctx.rank} does not match, got {tensor}, expect {initial_state_per_rank[rank][postfix]}"
+
+        sharded_sparse_arch.load_state_dict(initial_state_dict)
+
+        # sharded model
+        # each rank gets a subbatch
+        loss1, remapped_ids1 = sharded_sparse_arch(kjt_input)
+        loss2, remapped_ids2 = sharded_sparse_arch(kjt_input)
+
+        if not allow_in_place_embed_weight_update:
+            # Without in-place overwrite, the backward pass will fail due to a tensor version mismatch
+            with unittest.TestCase().assertRaisesRegex(
+                RuntimeError,
+                "one of the variables needed for gradient computation has been modified by an inplace operation",
+            ):
+                loss1.backward()
+        else:
+            loss1.backward()
+            loss2.backward()
+            final_state_dict = sharded_sparse_arch.state_dict()
+            for key, sharded_tensor in final_state_dict.items():
+                postfix = ".".join(key.split(".")[-2:])
+                if postfix in final_state_per_rank[ctx.rank]:
+                    tensor = sharded_tensor.local_shards()[0].tensor.cpu()
+                    assert torch.equal(
+                        tensor, final_state_per_rank[ctx.rank][postfix]
+                    ), f"final state {key} on {ctx.rank} does not match, got {tensor}, expect {final_state_per_rank[rank][postfix]}"
+
+            remapped_ids = [remapped_ids1, remapped_ids2]
+            for key in output_keys:
+                for i, kjt_out in enumerate(kjt_out_per_iter):
+                    assert torch.equal(
+                        remapped_ids[i][key].values(),
+                        kjt_out[key].values(),
+                    ), f"feature {key} on {ctx.rank} iteration {i} does not match, got {remapped_ids[i][key].values()}, expect {kjt_out[key].values()}"
+
+
 def _test_sharding_and_resharding(  # noqa C901
     tables: List[EmbeddingConfig],
     rank: int,
@@ -1016,3 +1118,166 @@ def test_sharding_zch_mc_ec_dedup_input_error(self, backend: str) -> None:
             ),
         except AssertionError as e:
             self.assertTrue("0 != 1" in str(e))
+
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
+    # pyre-ignore
+    @given(
+        backend=st.sampled_from(["nccl"]),
+        allow_in_place_embed_weight_update=st.booleans(),
+    )
+    @settings(deadline=None)
+    def test_in_place_embd_weight_update(
+        self, backend: str, allow_in_place_embed_weight_update: bool
+    ) -> None:
+
+        WORLD_SIZE = 2
+
+        embedding_config = [
+            EmbeddingConfig(
+                name="table_0",
+                feature_names=["feature_0"],
+                embedding_dim=8,
+                num_embeddings=16,
+            ),
+            EmbeddingConfig(
+                name="table_1",
+                feature_names=["feature_1"],
+                embedding_dim=8,
+                num_embeddings=32,
+            ),
+        ]
+
+        kjt_input_per_rank = [  # noqa
+            KeyedJaggedTensor.from_lengths_sync(
+                keys=["feature_0", "feature_1", "feature_2"],
+                values=torch.LongTensor(
+                    [1000, 2000, 1001, 2000, 2001, 2002, 1, 1, 1],
+                ),
+                lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
+                weights=None,
+            ),
+            KeyedJaggedTensor.from_lengths_sync(
+                keys=["feature_0", "feature_1", "feature_2"],
+                values=torch.LongTensor(
+                    [
+                        1000,
+                        1002,
+                        1004,
+                        2000,
+                        2002,
+                        2004,
+                        2,
+                        2,
+                        2,
+                    ],
+                ),
+                lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
+                weights=None,
+            ),
+        ]
+
+        kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]] = []
+        kjt_out_per_iter_per_rank.append(
+            [
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [7, 15, 7, 31, 31, 31],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [7, 7, 7, 31, 31, 31],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+            ]
+        )
+        # TODO: clean up sorting so the initial fill is more debuggable/logical
+
+        kjt_out_per_iter_per_rank.append(
+            [
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [3, 14, 4, 27, 29, 28],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+                KeyedJaggedTensor.from_lengths_sync(
+                    keys=["feature_0", "feature_1"],
+                    values=torch.LongTensor(
+                        [3, 5, 6, 27, 28, 30],
+                    ),
+                    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
+                    weights=None,
+                ),
+            ]
+        )
+
+        initial_state_per_rank = [
+            {
+                "table_0._mch_remapped_ids_mapping": torch.arange(8, dtype=torch.int64),
+                "table_1._mch_remapped_ids_mapping": torch.arange(
+                    16, dtype=torch.int64
+                ),
+            },
+            {
+                "table_0._mch_remapped_ids_mapping": torch.arange(
+                    start=8, end=16, dtype=torch.int64
+                ),
+                "table_1._mch_remapped_ids_mapping": torch.arange(
+                    start=16, end=32, dtype=torch.int64
+                ),
+            },
+        ]
+        max_int = torch.iinfo(torch.int64).max
+
+        final_state_per_rank = [
+            {
+                "table_0._mch_sorted_raw_ids": torch.LongTensor(
+                    [1000, 1001, 1002, 1004] + [max_int] * 4
+                ),
+                "table_1._mch_sorted_raw_ids": torch.LongTensor([max_int] * 16),
+                "table_0._mch_remapped_ids_mapping": torch.LongTensor(
+                    [3, 4, 5, 6, 0, 1, 2, 7]
+                ),
+                "table_1._mch_remapped_ids_mapping": torch.arange(
+                    16, dtype=torch.int64
+                ),
+            },
+            {
+                "table_0._mch_sorted_raw_ids": torch.LongTensor([2000] + [max_int] * 7),
+                "table_1._mch_sorted_raw_ids": torch.LongTensor(
+                    [2000, 2001, 2002, 2004] + [max_int] * 12
+                ),
+                "table_0._mch_remapped_ids_mapping": torch.LongTensor(
+                    [14, 8, 9, 10, 11, 12, 13, 15]
+                ),
+                "table_1._mch_remapped_ids_mapping": torch.LongTensor(
+                    [27, 29, 28, 30, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 31]
+                ),
+            },
+        ]
+
+        self._run_multi_process_test(
+            callable=_test_in_place_embd_weight_update,
+            output_keys=["feature_0", "feature_1"],
+            world_size=WORLD_SIZE,
+            tables=embedding_config,
+            kjt_input_per_rank=kjt_input_per_rank,
+            kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank,
+            initial_state_per_rank=initial_state_per_rank,
+            final_state_per_rank=final_state_per_rank,
+            sharder=ManagedCollisionEmbeddingCollectionSharder(),
+            backend=backend,
+            allow_in_place_embed_weight_update=allow_in_place_embed_weight_update,
+        )
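
The failure mode exercised when allow_in_place_embed_weight_update is disabled is plain PyTorch autograd behavior: rewriting a tensor that was saved for the backward pass bumps its version counter, and the next backward raises the "modified by an inplace operation" RuntimeError that the test asserts on. Below is a minimal sketch of that mechanism, independent of torchrec; the tensor names and the matmul stand-in are illustrative only and are not part of this patch.

import torch

# Weight matrix whose rows stand in for embedding rows; both operands require
# grad so that `weight` is saved for the matmul's backward pass.
weight = torch.randn(4, 3, requires_grad=True)
inputs = torch.randn(2, 4, requires_grad=True)

loss = (inputs @ weight).sum()

# Overwrite a weight row in place between forward and backward, mimicking a
# managed-collision remap rewriting embedding rows.
with torch.no_grad():
    weight[0].copy_(weight[1])

try:
    loss.backward()
except RuntimeError as err:
    # "one of the variables needed for gradient computation has been
    # modified by an inplace operation"
    print(err)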