@@ -492,7 +492,7 @@ def to_pooling_mode(pooling_type: PoolingType) -> PoolingMode:
         self._weight_init_mins: List[float] = []
         self._weight_init_maxs: List[float] = []
         self._num_embeddings: List[int] = []
-        self._embedding_dims: List[int] = []
+        self._local_cols: List[int] = []
         self._feature_table_map: List[int] = []
         self._emb_names: List[str] = []
         self._lengths_per_emb: List[int] = []
@@ -503,7 +503,7 @@ def to_pooling_mode(pooling_type: PoolingType) -> PoolingMode:
             self._weight_init_mins.append(config.get_weight_init_min())
             self._weight_init_maxs.append(config.get_weight_init_max())
             self._num_embeddings.append(config.num_embeddings)
-            self._embedding_dims.append(config.local_cols)
+            self._local_cols.append(config.local_cols)
             self._feature_table_map.extend([idx] * config.num_features())
             for feature_name in config.feature_names:
                 if feature_name not in shared_feature:
@@ -526,7 +526,7 @@ def init_parameters(self) -> None:
         )
         for (rows, emb_dim, weight_init_min, weight_init_max, param) in zip(
             self._local_rows,
-            self._embedding_dims,
+            self._local_cols,
             self._weight_init_mins,
             self._weight_init_maxs,
             self.emb_module.split_embedding_weights(),
@@ -616,13 +616,28 @@ def __init__(
         def to_rowwise_sharded_metadata(
             local_metadata: ShardMetadata,
             global_metadata: ShardedTensorMetadata,
+            sharding_dim: int,
         ) -> Tuple[ShardMetadata, ShardedTensorMetadata]:
             rw_shards: List[ShardMetadata] = []
             rw_local_shard: ShardMetadata = local_metadata
-            for shard in global_metadata.shards_metadata:
+            shards_metadata = global_metadata.shards_metadata
+            # column-wise sharding
+            # sort the metadata based on column offset and
+            # we construct the momentum tensor in row-wise sharded way
+            if sharding_dim == 1:
+                shards_metadata = sorted(
+                    shards_metadata, key=lambda shard: shard.shard_offsets[1]
+                )
+
+            for idx, shard in enumerate(shards_metadata):
+                offset = shard.shard_offsets[0]
+                # for column-wise sharding, we still create row-wise sharded metadata for optimizer
+                # manually create a row-wise offset
+                if sharding_dim == 1:
+                    offset = idx * shard.shard_lengths[0]
                 rw_shard = ShardMetadata(
                     shard_lengths=[shard.shard_lengths[0]],
-                    shard_offsets=[shard.shard_offsets[0]],
+                    shard_offsets=[offset],
                     placement=shard.placement,
                 )
                 rw_shards.append(rw_shard)
@@ -638,10 +653,10 @@ def to_rowwise_sharded_metadata(
                 memory_format=global_metadata.tensor_properties.memory_format,
                 pin_memory=global_metadata.tensor_properties.pin_memory,
             )
-
+            len_rw_shards = len(shards_metadata) if sharding_dim == 1 else 1
             rw_metadata = ShardedTensorMetadata(
                 shards_metadata=rw_shards,
-                size=torch.Size([global_metadata.size[0]]),
+                size=torch.Size([global_metadata.size[0] * len_rw_shards]),
                 tensor_properties=tensor_properties,
             )
             return rw_local_shard, rw_metadata
@@ -673,10 +688,15 @@ def to_rowwise_sharded_metadata(
             state[weight] = {}
             # momentum1
             assert table_config.local_rows == optimizer_states[0].size(0)
+            sharding_dim = (
+                1 if table_config.local_cols != table_config.embedding_dim else 0
+            )
             momentum1_key = f"{table_config.name}.momentum1"
             if optimizer_states[0].dim() == 1:
                 (local_metadata, sharded_tensor_metadata) = to_rowwise_sharded_metadata(
-                    table_config.local_metadata, table_config.global_metadata
+                    table_config.local_metadata,
+                    table_config.global_metadata,
+                    sharding_dim,
                 )
             else:
                 (local_metadata, sharded_tensor_metadata) = (
@@ -699,7 +719,9 @@ def to_rowwise_sharded_metadata(
                         local_metadata,
                         sharded_tensor_metadata,
                     ) = to_rowwise_sharded_metadata(
-                        table_config.local_metadata, table_config.global_metadata
+                        table_config.local_metadata,
+                        table_config.global_metadata,
+                        sharding_dim,
                     )
                 else:
                     (local_metadata, sharded_tensor_metadata) = (
@@ -769,9 +791,7 @@ def to_embedding_location(
         self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = (
             SplitTableBatchedEmbeddingBagsCodegen(
                 embedding_specs=list(
-                    zip(
-                        self._local_rows, self._embedding_dims, managed, compute_devices
-                    )
+                    zip(self._local_rows, self._local_cols, managed, compute_devices)
                 ),
                 feature_table_map=self._feature_table_map,
                 pooling_mode=self._pooling,
@@ -822,7 +842,7 @@ def __init__(

         self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = (
             DenseTableBatchedEmbeddingBagsCodegen(
-                list(zip(self._local_rows, self._embedding_dims)),
+                list(zip(self._local_rows, self._local_cols)),
                 feature_table_map=self._feature_table_map,
                 pooling_mode=self._pooling,
                 use_cpu=device is None or device.type == "cpu",
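Taken together, the hunks reduce to one layout rule for the fused optimizer state: a column-wise sharded table still gets row-wise sharded momentum metadata, built by sorting the column shards by their column offset and stacking each shard's rows along dim 0 (row offset idx * shard_lengths[0], total size global size[0] * number of shards). The snippet below is a minimal, self-contained sketch of that remapping; it uses a plain dataclass stand-in rather than the real ShardMetadata / ShardedTensorMetadata types, and it assumes every column-wise shard spans all rows of the table.

# Standalone sketch (not torchrec code): mirrors the offset remapping that
# to_rowwise_sharded_metadata performs when sharding_dim == 1 (column-wise).
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Shard:
    lengths: List[int]  # [rows, cols] of this shard
    offsets: List[int]  # [row_offset, col_offset] within the full table


def rowwise_momentum_layout(shards: List[Shard], sharding_dim: int) -> Tuple[List[int], int]:
    # Returns (per-shard row offsets, total rows) for the 1-D momentum tensor.
    if sharding_dim == 1:
        # Column-wise: order shards by column offset, then stack each shard's
        # momentum rows along dim 0, matching the manually created offsets in the diff.
        shards = sorted(shards, key=lambda s: s.offsets[1])
        row_offsets = [idx * s.lengths[0] for idx, s in enumerate(shards)]
        total_rows = shards[0].lengths[0] * len(shards)
    else:
        # Row-wise / table-wise: momentum rows already line up with weight rows.
        row_offsets = [s.offsets[0] for s in shards]
        total_rows = max(s.offsets[0] + s.lengths[0] for s in shards)
    return row_offsets, total_rows


# A 100-row table split column-wise into three 100 x 64 shards:
cw = [Shard([100, 64], [0, 128]), Shard([100, 64], [0, 0]), Shard([100, 64], [0, 64])]
print(rowwise_momentum_layout(cw, sharding_dim=1))  # ([0, 100, 200], 300)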