From fd386e0c00fcf1cbb97989d4bd42ebc5820bf07e Mon Sep 17 00:00:00 2001
From: "leongao@fb.com"
Date: Tue, 28 Sep 2021 11:23:26 -0700
Subject: [PATCH 1/2] ads ctr_mbl_feed 10x jun v0 2021h2 model (#4)

Summary:
Pull Request resolved: https://github.com/fairinternal/torchrec/pull/4

Differential Revision: D30878814

fbshipit-source-id: 94a3167e65cbc05fbc8f44873042e5c3f8cd7845
---
 distributed/embedding_lookup.py | 6 +++---
 modules/embedding_modules.py    | 7 +++++++
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/distributed/embedding_lookup.py b/distributed/embedding_lookup.py
index 130f087f3..bf51fb1cd 100644
--- a/distributed/embedding_lookup.py
+++ b/distributed/embedding_lookup.py
@@ -242,7 +242,7 @@ def forward(
         assert sparse_features.id_list_features is not None
         embeddings: List[torch.Tensor] = []
         id_list_features_by_group = sparse_features.id_list_features.split(
-            self._id_list_feature_splits
+            self._id_list_feature_splits,
         )
         for emb_op, features in zip(self._emb_modules, id_list_features_by_group):
             embeddings.append(emb_op(features).view(-1))
@@ -896,7 +896,7 @@ def forward(
         if len(self._emb_modules) > 0:
             assert sparse_features.id_list_features is not None
             id_list_features_by_group = sparse_features.id_list_features.split(
-                self._id_list_feature_splits
+                self._id_list_feature_splits,
             )
             for emb_op, features in zip(self._emb_modules, id_list_features_by_group):
                 embeddings.append(emb_op(features).values())
@@ -904,7 +904,7 @@ def forward(
             assert sparse_features.id_score_list_features is not None
             id_score_list_features_by_group = (
                 sparse_features.id_score_list_features.split(
-                    self._id_score_list_feature_splits
+                    self._id_score_list_feature_splits,
                 )
             )
             for emb_op, features in zip(
diff --git a/modules/embedding_modules.py b/modules/embedding_modules.py
index f1b2ef769..3111d0adf 100644
--- a/modules/embedding_modules.py
+++ b/modules/embedding_modules.py
@@ -5,6 +5,7 @@
 import torch
 import torch.nn as nn
 from torchrec.modules.embedding_configs import (
+    DataType,
     EmbeddingConfig,
     EmbeddingBagConfig,
     PoolingType,
@@ -108,12 +109,18 @@ def __init__(
             if embedding_config.name in table_names:
                 raise ValueError(f"Duplicate table name {embedding_config.name}")
             table_names.add(embedding_config.name)
+            dtype = (
+                torch.float32
+                if embedding_config.data_type == DataType.FP32
+                else torch.float16
+            )
             self.embedding_bags[embedding_config.name] = nn.EmbeddingBag(
                 num_embeddings=embedding_config.num_embeddings,
                 embedding_dim=embedding_config.embedding_dim,
                 mode=_to_mode(embedding_config.pooling),
                 device=device,
                 include_last_offset=True,
+                dtype=dtype,
             )
             if not embedding_config.feature_names:
                 embedding_config.feature_names = [embedding_config.name]

From fad101ac035fcb308d7fe97393f958cf56714434 Mon Sep 17 00:00:00 2001
From: Leon Gao
Date: Tue, 28 Sep 2021 11:23:40 -0700
Subject: [PATCH 2/2] weight init for ads (#5)

Summary:
Pull Request resolved: https://github.com/fairinternal/torchrec/pull/5

Reviewed By: dstaay-fb

Differential Revision: D31058782

fbshipit-source-id: 101519ad274e2ce2facdc334f0dfad041a89df93
---
 distributed/dp_sharding.py      |  2 ++
 distributed/embedding.py        |  2 ++
 distributed/embedding_lookup.py | 22 +++++++++++++++-------
 distributed/rw_sharding.py      |  2 ++
 distributed/tw_sharding.py      |  2 ++
 distributed/twrw_sharding.py    |  2 ++
 modules/embedding_configs.py    | 17 ++++++++++++++++-
 7 files changed, 41 insertions(+), 8 deletions(-)

diff --git a/distributed/dp_sharding.py b/distributed/dp_sharding.py
index 95f60315e..6d3c020d5 100644
--- a/distributed/dp_sharding.py
+++ b/distributed/dp_sharding.py
@@ -116,6 +116,8 @@ def _shard(
                     compute_kernel=EmbeddingComputeKernel(config[1].compute_kernel),
                     local_metadata=None,
                     global_metadata=None,
+                    weight_init_max=config[0].weight_init_max,
+                    weight_init_min=config[0].weight_init_min,
                 )
             )
         return tables_per_rank
diff --git a/distributed/embedding.py b/distributed/embedding.py
index bbce14e84..485d7a49f 100644
--- a/distributed/embedding.py
+++ b/distributed/embedding.py
@@ -131,6 +131,8 @@ def _create_embedding_configs_by_sharding(
                     pooling=config.pooling,
                     is_weighted=module.is_weighted,
                     embedding_names=embedding_names,
+                    weight_init_max=config.weight_init_max,
+                    weight_init_min=config.weight_init_min,
                 ),
                 parameter_sharding,
             )
diff --git a/distributed/embedding_lookup.py b/distributed/embedding_lookup.py
index bf51fb1cd..8706a3b07 100644
--- a/distributed/embedding_lookup.py
+++ b/distributed/embedding_lookup.py
@@ -114,8 +114,8 @@ def __init__(
                     embedding_config.local_cols,
                     device=device,
                 ).uniform_(
-                    -sqrt(1 / embedding_config.num_embeddings),
-                    sqrt(1 / embedding_config.num_embeddings),
+                    embedding_config.get_weight_init_min(),
+                    embedding_config.get_weight_init_max(),
                 ),
             )
         )
@@ -362,8 +362,8 @@ def _to_mode(pooling: PoolingType) -> str:
                     embedding_config.local_cols,
                     device=device,
                 ).uniform_(
-                    -sqrt(1 / embedding_config.num_embeddings),
-                    sqrt(1 / embedding_config.num_embeddings),
+                    embedding_config.get_weight_init_min(),
+                    embedding_config.get_weight_init_max(),
                 ),
             )
         )
@@ -479,6 +479,8 @@ def to_pooling_mode(pooling_type: PoolingType) -> PoolingMode:
         self._pooling: PoolingMode = to_pooling_mode(config.pooling)
 
         self._local_rows: List[int] = []
+        self._weight_init_mins: List[float] = []
+        self._weight_init_maxs: List[float] = []
         self._num_embeddings: List[int] = []
         self._embedding_dims: List[int] = []
         self._feature_table_map: List[int] = []
@@ -488,6 +490,8 @@ def to_pooling_mode(pooling_type: PoolingType) -> PoolingMode:
         shared_feature: Dict[str, bool] = {}
         for idx, config in enumerate(self._config.embedding_tables):
             self._local_rows.append(config.local_rows)
+            self._weight_init_mins.append(config.get_weight_init_min())
+            self._weight_init_maxs.append(config.get_weight_init_max())
             self._num_embeddings.append(config.num_embeddings)
             self._embedding_dims.append(config.local_cols)
             self._feature_table_map.extend([idx] * config.num_features())
@@ -510,14 +514,18 @@ def init_parameters(self) -> None:
         assert len(self._num_embeddings) == len(
             self.emb_module.split_embedding_weights()
         )
-        for (rows, num_emb, emb_dim, param) in zip(
+        for (rows, emb_dim, weight_init_min, weight_init_max, param) in zip(
             self._local_rows,
-            self._num_embeddings,
             self._embedding_dims,
+            self._weight_init_mins,
+            self._weight_init_maxs,
             self.emb_module.split_embedding_weights(),
         ):
             assert param.shape == (rows, emb_dim)
-            param.data.uniform_(-sqrt(1 / num_emb), sqrt(1 / num_emb))
+            param.data.uniform_(
+                weight_init_min,
+                weight_init_max,
+            )
 
     def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
         values = self.emb_module(
diff --git a/distributed/rw_sharding.py b/distributed/rw_sharding.py
index a1566ff97..380aea3eb 100644
--- a/distributed/rw_sharding.py
+++ b/distributed/rw_sharding.py
@@ -227,6 +227,8 @@ def _shard(
                     compute_kernel=EmbeddingComputeKernel(config[1].compute_kernel),
                     local_metadata=shards[rank],
                     global_metadata=global_metadata,
+                    weight_init_max=config[0].weight_init_max,
+                    weight_init_min=config[0].weight_init_min,
                 )
             )
         return tables_per_rank
diff --git a/distributed/tw_sharding.py b/distributed/tw_sharding.py
index e14948a72..adc0cd583 100644
--- a/distributed/tw_sharding.py
+++ b/distributed/tw_sharding.py
@@ -168,6 +168,8 @@ def _shard(
                     compute_kernel=EmbeddingComputeKernel(config[1].compute_kernel),
                     local_metadata=shards[0],
                     global_metadata=global_metadata,
+                    weight_init_max=config[0].weight_init_max,
+                    weight_init_min=config[0].weight_init_min,
                 )
             )
         return tables_per_rank
diff --git a/distributed/twrw_sharding.py b/distributed/twrw_sharding.py
index 8fca68e8f..834111954 100644
--- a/distributed/twrw_sharding.py
+++ b/distributed/twrw_sharding.py
@@ -274,6 +274,8 @@ def _shard(
                     compute_kernel=EmbeddingComputeKernel(config[1].compute_kernel),
                     local_metadata=shards[rank_idx],
                     global_metadata=global_metadata,
+                    weight_init_max=config[0].weight_init_max,
+                    weight_init_min=config[0].weight_init_min,
                 )
             )
         return tables_per_rank
diff --git a/modules/embedding_configs.py b/modules/embedding_configs.py
index eef5dd197..19bddc418 100644
--- a/modules/embedding_configs.py
+++ b/modules/embedding_configs.py
@@ -2,7 +2,8 @@
 
 from dataclasses import dataclass, field
 from enum import Enum, unique
-from typing import List, Dict
+from math import sqrt
+from typing import Optional, List, Dict
 
 
 @unique
@@ -39,6 +40,20 @@ class BaseEmbeddingConfig:
     name: str = ""
     data_type: DataType = DataType.FP32
     feature_names: List[str] = field(default_factory=list)
+    weight_init_max: Optional[float] = None
+    weight_init_min: Optional[float] = None
+
+    def get_weight_init_max(self) -> float:
+        if self.weight_init_max is None:
+            return sqrt(1 / self.num_embeddings)
+        else:
+            return self.weight_init_max
+
+    def get_weight_init_min(self) -> float:
+        if self.weight_init_min is None:
+            return -sqrt(1 / self.num_embeddings)
+        else:
+            return self.weight_init_min
 
     def num_features(self) -> int:
         return len(self.feature_names)
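
Note on the weight-init change (illustration only, not part of the patches above): the sketch below mirrors the defaulting logic that PATCH 2/2 adds to BaseEmbeddingConfig, where an explicit weight_init_min/weight_init_max wins and the fallback is +/- sqrt(1 / num_embeddings). The SimpleTableConfig class and the standalone script structure are assumptions made for this example; they are not torchrec APIs.

# Standalone sketch (assumed names, not torchrec code) of the weight-init
# defaulting introduced in PATCH 2/2.
from dataclasses import dataclass
from math import sqrt
from typing import Optional

import torch


@dataclass
class SimpleTableConfig:  # hypothetical stand-in for BaseEmbeddingConfig
    num_embeddings: int
    embedding_dim: int
    weight_init_max: Optional[float] = None  # None -> use sqrt(1 / num_embeddings)
    weight_init_min: Optional[float] = None  # None -> use -sqrt(1 / num_embeddings)

    def get_weight_init_max(self) -> float:
        if self.weight_init_max is None:
            return sqrt(1 / self.num_embeddings)
        return self.weight_init_max

    def get_weight_init_min(self) -> float:
        if self.weight_init_min is None:
            return -sqrt(1 / self.num_embeddings)
        return self.weight_init_min


# Default bounds for this table: +/- sqrt(1 / 10_000) = +/- 0.01.
config = SimpleTableConfig(num_embeddings=10_000, embedding_dim=16)
weight = torch.empty(config.num_embeddings, config.embedding_dim).uniform_(
    config.get_weight_init_min(), config.get_weight_init_max()
)
print(weight.min().item(), weight.max().item())

Keeping None as the default preserves the previous +/- sqrt(1 / num_embeddings) behavior, while per-table overrides flow through the dp/rw/tw/twrw sharders via the new weight_init_max/weight_init_min fields.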