 
 import copy
 from dataclasses import dataclass
+from functools import partial
 from typing import Any, Dict, List, Mapping, Optional, Tuple, TypeVar, Union
 
 import torch
@@ -46,6 +47,47 @@ def _append_table_shard(
     d[table_name].append(shard)
 
 
+def post_state_dict_hook(
+    # Union["ShardedQuantEmbeddingBagCollection", "ShardedQuantEmbeddingCollection"]
+    # pyre-ignore [24]
+    module: ShardedEmbeddingModule,
+    destination: Dict[str, torch.Tensor],
+    prefix: str,
+    _local_metadata: Dict[str, Any],
+    tables_weights_prefix: str,  # "embedding_bags" or "embeddings"
+) -> None:
+    for (
+        table_name,
+        sharded_t,
+    ) in module._table_name_to_sharded_tensor.items():
+        destination[f"{prefix}{tables_weights_prefix}.{table_name}.weight"] = sharded_t
+
+    for sfx, dict_sharded_t, dict_t_list in [
+        (
+            "weight_qscale",
+            module._table_name_to_sharded_tensor_qscale,
+            module._table_name_to_tensors_list_qscale,
+        ),
+        (
+            "weight_qbias",
+            module._table_name_to_sharded_tensor_qbias,
+            module._table_name_to_tensors_list_qbias,
+        ),
+    ]:
+        for (
+            table_name,
+            sharded_t,
+        ) in dict_sharded_t.items():
+            destination[f"{prefix}{tables_weights_prefix}.{table_name}.{sfx}"] = (
+                sharded_t
+            )
+        for (
+            table_name,
+            t_list,
+        ) in dict_t_list.items():
+            destination[f"{prefix}{tables_weights_prefix}.{table_name}.{sfx}"] = t_list
+
+
 class ShardedQuantEmbeddingModuleState(
     ShardedEmbeddingModule[CompIn, DistOut, Out, ShrdCtx]
 ):
@@ -82,17 +124,6 @@ def _initialize_torch_state(  # noqa: C901
         ] = {}
         self._table_name_to_tensors_list_qbias: Dict[str, List[torch.Tensor]] = {}
 
-        # pruning_index_remappings
-        self._table_name_to_local_shards_pruning_index_remappings: Dict[
-            str, List[Shard]
-        ] = {}
-        self._table_name_to_sharded_tensor_pruning_index_remappings: Dict[
-            str, Union[torch.Tensor, ShardedTensorBase]
-        ] = {}
-        self._table_name_to_tensors_list_pruning_index_remappings: Dict[
-            str, List[torch.Tensor]
-        ] = {}
-
         for tbe, config in tbes.items():
             for (tbe_split_w, tbe_split_qscale, tbe_split_qbias), table in zip(
                 tbe.split_embedding_weights_with_scale_bias(split_scale_bias_mode=2),
@@ -184,43 +215,6 @@ def _initialize_torch_state(  # noqa: C901
                     Shard(tensor=tbe_split_qparam, metadata=qmetadata),
                 )
                 # end of weight_qscale & weight_qbias section
-                if table.pruning_indices_remapping is not None:
-                    for (
-                        qparam,
-                        table_name_to_local_shards,
-                        _,
-                    ) in [
-                        (
-                            table.pruning_indices_remapping,
-                            self._table_name_to_local_shards_pruning_index_remappings,
-                            self._table_name_to_tensors_list_pruning_index_remappings,
-                        )
-                    ]:
-                        parameter_sharding: ParameterSharding = (
-                            table_name_to_parameter_sharding[table.name]
-                        )
-                        sharding_type: str = parameter_sharding.sharding_type
-
-                        assert sharding_type in [
-                            ShardingType.TABLE_WISE.value,
-                            ShardingType.COLUMN_WISE.value,
-                        ]
-
-                        qmetadata = ShardMetadata(
-                            shard_offsets=[0],
-                            shard_sizes=[
-                                qparam.shape[0],
-                            ],
-                            placement=table.local_metadata.placement,
-                        )
-                        # TODO(ivankobzarev): "meta" sharding support: cleanup when copy to "meta" moves all tensors to "meta"
-                        if qmetadata.placement.device != qparam.device:
-                            qmetadata.placement = _remote_device(qparam.device)
-                        _append_table_shard(
-                            table_name_to_local_shards,
-                            table.name,
-                            Shard(tensor=qparam, metadata=qmetadata),
-                        )
 
         for table_name_to_local_shards, table_name_to_sharded_tensor in [
             (self._table_name_to_local_shards, self._table_name_to_sharded_tensor),
@@ -263,65 +257,9 @@ def _initialize_torch_state(  # noqa: C901
                     )
                 )
 
-        for table_name_to_local_shards, table_name_to_sharded_tensor in [
-            (
-                self._table_name_to_local_shards_pruning_index_remappings,
-                self._table_name_to_sharded_tensor_pruning_index_remappings,
-            ),
-        ]:
-            for table_name, local_shards in table_name_to_local_shards.items():
-                # Single Tensor per table (TW sharding)
-                table_name_to_sharded_tensor[table_name] = local_shards[0].tensor
-                continue
-
-        def post_state_dict_hook(
-            # Union["ShardedQuantEmbeddingBagCollection", "ShardedQuantEmbeddingCollection"]
-            module: ShardedQuantEmbeddingModuleState[CompIn, DistOut, Out, ShrdCtx],
-            destination: Dict[str, torch.Tensor],
-            prefix: str,
-            _local_metadata: Dict[str, Any],
-        ) -> None:
-            for (
-                table_name,
-                sharded_t,
-            ) in module._table_name_to_sharded_tensor.items():
-                destination[f"{prefix}{tables_weights_prefix}.{table_name}.weight"] = (
-                    sharded_t
-                )
-
-            for sfx, dict_sharded_t, dict_t_list in [
-                (
-                    "weight_qscale",
-                    module._table_name_to_sharded_tensor_qscale,
-                    module._table_name_to_tensors_list_qscale,
-                ),
-                (
-                    "weight_qbias",
-                    module._table_name_to_sharded_tensor_qbias,
-                    module._table_name_to_tensors_list_qbias,
-                ),
-                (
-                    "index_remappings_array",
-                    module._table_name_to_sharded_tensor_pruning_index_remappings,
-                    module._table_name_to_tensors_list_pruning_index_remappings,
-                ),
-            ]:
-                for (
-                    table_name,
-                    sharded_t,
-                ) in dict_sharded_t.items():
-                    destination[
-                        f"{prefix}{tables_weights_prefix}.{table_name}.{sfx}"
-                    ] = sharded_t
-                for (
-                    table_name,
-                    t_list,
-                ) in dict_t_list.items():
-                    destination[
-                        f"{prefix}{tables_weights_prefix}.{table_name}.{sfx}"
-                    ] = t_list
-
-        self._register_state_dict_hook(post_state_dict_hook)
+        self._register_state_dict_hook(
+            partial(post_state_dict_hook, tables_weights_prefix=tables_weights_prefix)
+        )
 
     def _load_from_state_dict(
         # Union["ShardedQuantEmbeddingBagCollection", "ShardedQuantEmbeddingCollection"]
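For context, here is a minimal, self-contained sketch (not part of this diff) of the pattern the change relies on: `functools.partial` pre-binds the extra `tables_weights_prefix` argument so a module-level hook still matches the `(module, state_dict, prefix, local_metadata)` call shape that PyTorch's private `_register_state_dict_hook` uses when `state_dict()` runs. `demo_hook` and `DemoModule` are made-up names for illustration, not part of torchrec.

# Standalone sketch of partial-bound state_dict hooks; assumes PyTorch's
# private nn.Module._register_state_dict_hook API. Illustrative names only.
from functools import partial
from typing import Any, Dict

import torch
from torch import nn


def demo_hook(
    module: nn.Module,
    destination: Dict[str, torch.Tensor],
    prefix: str,
    _local_metadata: Dict[str, Any],
    tables_weights_prefix: str,  # extra argument, bound via partial below
) -> None:
    # Re-key this module's entries under an extra namespace, loosely mimicking
    # how the real hook publishes per-table weights under
    # "embedding_bags"/"embeddings".
    for name in list(destination.keys()):
        if name.startswith(prefix):
            destination[f"{prefix}{tables_weights_prefix}.{name[len(prefix):]}"] = (
                destination.pop(name)
            )


class DemoModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(2, 4))
        # partial() fixes tables_weights_prefix, so the registered callable
        # keeps the four-argument shape the state_dict machinery invokes.
        self._register_state_dict_hook(
            partial(demo_hook, tables_weights_prefix="embedding_bags")
        )


print(DemoModule().state_dict().keys())  # expected: a single "embedding_bags.weight" key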
|
|