
Commit df0dd90

Leon Gao authored and facebook-github-bot committed
col-wise ads config (#6)
Summary: Pull Request resolved: #6
Reviewed By: dstaay-fb
Differential Revision: D31058649
fbshipit-source-id: 06d3aa668c170bebe4e2cf7f12551ca7129504fb
1 parent 1086953 commit df0dd90

3 files changed: +37, -18 lines
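
As background for the diffs below, a minimal sketch of what column-wise sharding means for a table's local shape (the table size and shard count here are hypothetical, not taken from this commit): every column-wise shard keeps all of the table's rows but only a slice of its columns, so a shard's local column count (local_cols) no longer matches the table's full embedding_dim.

# Hypothetical column-wise split of one embedding table across 4 ranks.
num_embeddings = 1_000_000   # rows (hash size); every CW shard keeps all rows
embedding_dim = 128          # full width of the table
num_shards = 4

local_cols = embedding_dim // num_shards   # each shard owns 32 columns
shards = [
    {
        "shard_lengths": [num_embeddings, local_cols],
        "shard_offsets": [0, idx * local_cols],   # offset along dim 1 (columns)
    }
    for idx in range(num_shards)
]
assert all(s["shard_lengths"][1] != embedding_dim for s in shards)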

distributed/embedding_lookup.py

Lines changed: 33 additions & 13 deletions
@@ -492,7 +492,7 @@ def to_pooling_mode(pooling_type: PoolingType) -> PoolingMode:
         self._weight_init_mins: List[float] = []
         self._weight_init_maxs: List[float] = []
         self._num_embeddings: List[int] = []
-        self._embedding_dims: List[int] = []
+        self._local_cols: List[int] = []
         self._feature_table_map: List[int] = []
         self._emb_names: List[str] = []
         self._lengths_per_emb: List[int] = []
@@ -503,7 +503,7 @@ def to_pooling_mode(pooling_type: PoolingType) -> PoolingMode:
             self._weight_init_mins.append(config.get_weight_init_min())
             self._weight_init_maxs.append(config.get_weight_init_max())
             self._num_embeddings.append(config.num_embeddings)
-            self._embedding_dims.append(config.local_cols)
+            self._local_cols.append(config.local_cols)
             self._feature_table_map.extend([idx] * config.num_features())
             for feature_name in config.feature_names:
                 if feature_name not in shared_feature:
@@ -526,7 +526,7 @@ def init_parameters(self) -> None:
         )
         for (rows, emb_dim, weight_init_min, weight_init_max, param) in zip(
             self._local_rows,
-            self._embedding_dims,
+            self._local_cols,
             self._weight_init_mins,
             self._weight_init_maxs,
             self.emb_module.split_embedding_weights(),
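
The three hunks above rename _embedding_dims to _local_cols so that parameter initialization and the batched-kernel specs are driven by the per-shard column count rather than the full table width. A minimal, hedged sketch of the init loop under that convention (the shapes and init ranges are hypothetical, and plain lists stand in for the module's config-derived state):

import torch

# Hypothetical per-table shard shapes and init ranges.
local_rows = [1000, 2000]
local_cols = [32, 64]              # per-shard widths, not full embedding dims
weight_init_mins = [-0.01, -0.02]
weight_init_maxs = [0.01, 0.02]

params = [torch.empty(rows, cols) for rows, cols in zip(local_rows, local_cols)]
for rows, cols, lo, hi, param in zip(
    local_rows, local_cols, weight_init_mins, weight_init_maxs, params
):
    # each local shard is initialized over its own (local_rows, local_cols) shape
    param.uniform_(lo, hi)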
@@ -616,13 +616,28 @@ def __init__(
         def to_rowwise_sharded_metadata(
             local_metadata: ShardMetadata,
             global_metadata: ShardedTensorMetadata,
+            sharding_dim: int,
         ) -> Tuple[ShardMetadata, ShardedTensorMetadata]:
             rw_shards: List[ShardMetadata] = []
             rw_local_shard: ShardMetadata = local_metadata
-            for shard in global_metadata.shards_metadata:
+            shards_metadata = global_metadata.shards_metadata
+            # column-wise sharding
+            # sort the metadata based on column offset and
+            # we construct the momentum tensor in row-wise sharded way
+            if sharding_dim == 1:
+                shards_metadata = sorted(
+                    shards_metadata, key=lambda shard: shard.shard_offsets[1]
+                )
+
+            for idx, shard in enumerate(shards_metadata):
+                offset = shard.shard_offsets[0]
+                # for column-wise sharding, we still create row-wise sharded metadata for optimizer
+                # manually create a row-wise offset
+                if sharding_dim == 1:
+                    offset = idx * shard.shard_lengths[0]
                 rw_shard = ShardMetadata(
                     shard_lengths=[shard.shard_lengths[0]],
-                    shard_offsets=[shard.shard_offsets[0]],
+                    shard_offsets=[offset],
                     placement=shard.placement,
                 )
 
@@ -638,10 +653,10 @@ def to_rowwise_sharded_metadata(
                 memory_format=global_metadata.tensor_properties.memory_format,
                 pin_memory=global_metadata.tensor_properties.pin_memory,
             )
-
+            len_rw_shards = len(shards_metadata) if sharding_dim == 1 else 1
             rw_metadata = ShardedTensorMetadata(
                 shards_metadata=rw_shards,
-                size=torch.Size([global_metadata.size[0]]),
+                size=torch.Size([global_metadata.size[0] * len_rw_shards]),
                 tensor_properties=tensor_properties,
             )
             return rw_local_shard, rw_metadata
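
To make the effect of the new sharding_dim handling concrete, here is a hedged, pure-Python walk-through of the remapping above (plain dicts stand in for ShardMetadata, and the 3-shard layout is hypothetical): for a column-wise sharded table the shards are first ordered by column offset, each shard's 1-D momentum1 state then gets a fabricated row-wise offset of idx * shard_lengths[0], and the global size grows to size[0] * len_rw_shards.

rows, cols_per_shard = 1000, 32

# Column-wise shards of a hypothetical (1000, 96) table: all rows, staggered column offsets.
cw_shards = [
    {"shard_lengths": [rows, cols_per_shard], "shard_offsets": [0, 64]},
    {"shard_lengths": [rows, cols_per_shard], "shard_offsets": [0, 0]},
    {"shard_lengths": [rows, cols_per_shard], "shard_offsets": [0, 32]},
]

# Sort by column offset, then build row-wise metadata for the 1-D optimizer state.
cw_shards = sorted(cw_shards, key=lambda s: s["shard_offsets"][1])
rw_shards = [
    {"shard_lengths": [s["shard_lengths"][0]], "shard_offsets": [idx * s["shard_lengths"][0]]}
    for idx, s in enumerate(cw_shards)
]

assert [s["shard_offsets"][0] for s in rw_shards] == [0, 1000, 2000]
global_rows = rows * len(cw_shards)   # size[0] * len_rw_shards = 3000
assert global_rows == 3000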
@@ -673,10 +688,15 @@ def to_rowwise_sharded_metadata(
             state[weight] = {}
             # momentum1
             assert table_config.local_rows == optimizer_states[0].size(0)
+            sharding_dim = (
+                1 if table_config.local_cols != table_config.embedding_dim else 0
+            )
             momentum1_key = f"{table_config.name}.momentum1"
             if optimizer_states[0].dim() == 1:
                 (local_metadata, sharded_tensor_metadata) = to_rowwise_sharded_metadata(
-                    table_config.local_metadata, table_config.global_metadata
+                    table_config.local_metadata,
+                    table_config.global_metadata,
+                    sharding_dim,
                 )
             else:
                 (local_metadata, sharded_tensor_metadata) = (
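
The sharding_dim passed into to_rowwise_sharded_metadata is inferred from the table config itself: if the shard's local column count is narrower than the table's full embedding_dim, the table must have been split along dim 1. A tiny sketch with hypothetical widths:

# (local_cols, embedding_dim) pairs for two hypothetical tables.
for local_cols, embedding_dim in [(128, 128), (32, 128)]:
    # dim 1 (columns) when the local shard is narrower than the full table, else dim 0
    sharding_dim = 1 if local_cols != embedding_dim else 0
    print(local_cols, embedding_dim, "->", sharding_dim)   # 128 128 -> 0, then 32 128 -> 1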
@@ -699,7 +719,9 @@ def to_rowwise_sharded_metadata(
                     local_metadata,
                     sharded_tensor_metadata,
                 ) = to_rowwise_sharded_metadata(
-                    table_config.local_metadata, table_config.global_metadata
+                    table_config.local_metadata,
+                    table_config.global_metadata,
+                    sharding_dim,
                 )
             else:
                 (local_metadata, sharded_tensor_metadata) = (
@@ -769,9 +791,7 @@ def to_embedding_location(
         self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = (
             SplitTableBatchedEmbeddingBagsCodegen(
                 embedding_specs=list(
-                    zip(
-                        self._local_rows, self._embedding_dims, managed, compute_devices
-                    )
+                    zip(self._local_rows, self._local_cols, managed, compute_devices)
                 ),
                 feature_table_map=self._feature_table_map,
                 pooling_mode=self._pooling,
@@ -822,7 +842,7 @@ def __init__(
 
         self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = (
             DenseTableBatchedEmbeddingBagsCodegen(
-                list(zip(self._local_rows, self._embedding_dims)),
+                list(zip(self._local_rows, self._local_cols)),
                 feature_table_map=self._feature_table_map,
                 pooling_mode=self._pooling,
                 use_cpu=device is None or device.type == "cpu",
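
With the rename in place, the batched modules build their per-table specs from local shapes. A hedged sketch of that zip (in the real module managed and compute_devices hold fbgemm_gpu enums; strings are used here only to keep the example self-contained, and the shapes are hypothetical):

# One spec per table: (local rows, local columns, placement, compute device).
local_rows = [1000, 2000, 500]
local_cols = [32, 32, 64]          # per-shard widths after a column-wise split
managed = ["DEVICE", "DEVICE", "MANAGED"]
compute_devices = ["CUDA", "CUDA", "CUDA"]

embedding_specs = list(zip(local_rows, local_cols, managed, compute_devices))
print(embedding_specs[0])   # (1000, 32, 'DEVICE', 'CUDA')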

distributed/planner/embedding_planner.py

Lines changed: 1 addition & 5 deletions
@@ -31,7 +31,6 @@
     deallocate_param,
     param_sort_key,
     to_plan,
-    MIN_DIM,
 )
 from torchrec.distributed.types import (
     ShardingPlan,
@@ -398,13 +397,10 @@ def _get_num_col_wise_shards(
         col_wise_shard_dim = (
             col_wise_shard_dim_hint
             if col_wise_shard_dim_hint is not None
-            else MIN_DIM
+            else param.shape[1]
        )
         # column-wise shard the weights
         num_col_wise_shards, residual = divmod(param.shape[1], col_wise_shard_dim)
-        assert (
-            num_col_wise_shards > 0
-        ), f"the table {name} cannot be column-wise sharded into shards of {col_wise_shard_dim} dimensions"
         if residual > 0:
             num_col_wise_shards += 1
     elif sharding_type == ShardingType.TABLE_WISE.value:
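
The planner change above replaces the MIN_DIM default with the table's full width, so an un-hinted column-wise table now resolves to a single shard, while a col_wise_shard_dim hint still drives the divmod-plus-residual split. A hedged sketch of that arithmetic (the helper name and the example widths are ours, not the planner's):

def num_cw_shards(num_cols, col_wise_shard_dim_hint=None):
    # Default shard width is now the full table width instead of MIN_DIM.
    col_wise_shard_dim = (
        col_wise_shard_dim_hint if col_wise_shard_dim_hint is not None else num_cols
    )
    shards, residual = divmod(num_cols, col_wise_shard_dim)
    if residual > 0:
        shards += 1   # the last shard carries the leftover columns
    return shards

assert num_cw_shards(128) == 1        # no hint: one full-width shard
assert num_cw_shards(128, 32) == 4    # evenly divisible hint
assert num_cw_shards(128, 48) == 3    # divmod(128, 48) = (2, 32) -> 3 shards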

distributed/planner/tests/test_embedding_planner.py

Lines changed: 3 additions & 0 deletions
@@ -551,6 +551,7 @@ def test_allocation_planner_cw_balanced(self) -> None:
             hints={
                 "table_0": ParameterHints(
                     sharding_types=[ShardingType.COLUMN_WISE.value],
+                    col_wise_shard_dim=32,
                 ),
             },
         )
@@ -653,9 +654,11 @@ def test_allocation_planner_cw_two_big_rest_small_with_residual(self) -> None:
             hints={
                 "table_0": ParameterHints(
                     sharding_types=[ShardingType.COLUMN_WISE.value],
+                    col_wise_shard_dim=32,
                 ),
                 "table_1": ParameterHints(
                     sharding_types=[ShardingType.COLUMN_WISE.value],
+                    col_wise_shard_dim=32,
                 ),
             },
         )
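
The tests pin col_wise_shard_dim=32 presumably because, under the new full-width default, these column-wise tables would otherwise collapse into a single shard and stop exercising multi-shard plans. A small illustration of that effect (the 64-column width is hypothetical, not read from the test fixtures):

embedding_dim = 64
print(divmod(embedding_dim, embedding_dim))   # (1, 0): no hint -> one shard
print(divmod(embedding_dim, 32))              # (2, 0): hint of 32 -> two shards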
