From 1f04b46dd02ee8c5506ec76911661c0667a4c9f8 Mon Sep 17 00:00:00 2001
From: Leon Gao
Date: Mon, 20 Sep 2021 18:39:13 -0700
Subject: [PATCH] fix parameter placement

Summary: Fix tensor placement: the remote device placement should be derived from {rank, local_rank}, i.e. on CUDA the device index becomes rank % local_size instead of reusing the global rank.

Differential Revision: D31072120

fbshipit-source-id: 13de19a31a5cafeef280ed7b38a8372a4038fe89
---
 distributed/embedding_lookup.py           |  1 +
 distributed/planner/embedding_planner.py  |  1 +
 distributed/planner/parameter_sharding.py | 50 +++++++++++++++--------
 distributed/planner/utils.py              |  4 +-
 4 files changed, 38 insertions(+), 18 deletions(-)

diff --git a/distributed/embedding_lookup.py b/distributed/embedding_lookup.py
index f98e18662..224774350 100644
--- a/distributed/embedding_lookup.py
+++ b/distributed/embedding_lookup.py
@@ -738,6 +738,7 @@ def to_embedding_location(
                     feature_table_map=self._feature_table_map,
                     pooling_mode=self._pooling,
                     weights_precision=to_sparse_type(config.data_type),
+                    device=device,
                     **fused_params,
                 )
             )
diff --git a/distributed/planner/embedding_planner.py b/distributed/planner/embedding_planner.py
index e1041e405..63e74a333 100644
--- a/distributed/planner/embedding_planner.py
+++ b/distributed/planner/embedding_planner.py
@@ -124,6 +124,7 @@ def plan(
 
         return to_plan(
             [param_info for _, param_info in placed_param_infos],
+            self._device,
             self._world_size,
             self._local_size,
         )
diff --git a/distributed/planner/parameter_sharding.py b/distributed/planner/parameter_sharding.py
index cacc99883..78fbe989a 100644
--- a/distributed/planner/parameter_sharding.py
+++ b/distributed/planner/parameter_sharding.py
@@ -2,7 +2,7 @@
 import abc
 import itertools
 import math
-from typing import List, Tuple, Optional
+from typing import List, Tuple
 
 import torch
 from torch.distributed._sharding_spec import EnumerableShardingSpec, ShardMetadata
@@ -63,34 +63,46 @@ def _rw_shard_table_rows(hash_size: int, world_size: int) -> Tuple[List[int], in
     return (local_rows, block_size, last_rank)
 
 
+def _device_placement(
+    device: torch.device,
+    rank: int,
+    local_size: int,
+) -> str:
+    param_device = device
+    if device.type == "cuda":
+        param_device = torch.device("cuda", rank % local_size)
+    return f"rank:{rank}/{param_device}"
+
+
 class ParameterShardingFactory(abc.ABC):
     @staticmethod
     def shard_parameters(
         param_info: ParameterInfo,
+        device: torch.device,
         world_size: int,
-        local_size: Optional[int],
+        local_size: int,
     ) -> ParameterSharding:
         sharding_option = param_info.sharding_options[0]
         sharding_type = sharding_option.sharding_type
         if sharding_type == ShardingType.TABLE_WISE.value:
             parameter_sharding = TwParameterSharding.shard_parameters(
-                param_info, world_size, local_size
+                param_info, device, world_size, local_size
             )
         elif sharding_type == ShardingType.ROW_WISE.value:
             parameter_sharding = RwParameterSharding.shard_parameters(
-                param_info, world_size, local_size
+                param_info, device, world_size, local_size
             )
         elif sharding_type == ShardingType.TABLE_ROW_WISE.value:
             parameter_sharding = TwRwParameterSharding.shard_parameters(
-                param_info, world_size, local_size
+                param_info, device, world_size, local_size
             )
         elif sharding_type == ShardingType.COLUMN_WISE.value:
             parameter_sharding = CwParameterSharding.shard_parameters(
-                param_info, world_size, local_size
+                param_info, device, world_size, local_size
            )
         elif sharding_type == ShardingType.DATA_PARALLEL.value:
             parameter_sharding = DpParameterSharding.shard_parameters(
-                param_info, world_size, local_size
+                param_info, device, world_size, local_size
             )
         else:
             raise ValueError(
@@ -104,8 +116,9 @@ class TwParameterSharding:
     def shard_parameters(
         cls,
         param_info: ParameterInfo,
+        device: torch.device,
         world_size: int,
-        local_size: Optional[int],
+        local_size: int,
     ) -> ParameterSharding:
         sharding_option = param_info.sharding_options[0]
         tensor = param_info.param
@@ -118,7 +131,7 @@ def shard_parameters(
                     tensor.shape[1],
                 ],
                 shard_offsets=[0, 0],
-                placement=f"rank:{rank}/cuda:{rank}",
+                placement=_device_placement(device, rank, local_size),
             )
         ]
         return ParameterSharding(
@@ -134,8 +147,9 @@ class RwParameterSharding:
     def shard_parameters(
         cls,
         param_info: ParameterInfo,
+        device: torch.device,
         world_size: int,
-        local_size: Optional[int],
+        local_size: int,
     ) -> ParameterSharding:
         sharding_option = param_info.sharding_options[0]
         tensor = param_info.param
@@ -149,7 +163,7 @@ def shard_parameters(
                     tensor.shape[1],
                 ],
                 shard_offsets=[block_size * min(rank, last_rank), 0],
-                placement=f"rank:{rank}/cuda:{rank}",
+                placement=_device_placement(device, rank, local_size),
             )
             for rank in range(world_size)
         ]
@@ -166,8 +180,9 @@ class TwRwParameterSharding:
     def shard_parameters(
         cls,
         param_info: ParameterInfo,
+        device: torch.device,
         world_size: int,
-        local_size: Optional[int],
+        local_size: int,
     ) -> ParameterSharding:
         sharding_option = param_info.sharding_options[0]
         tensor = param_info.param
@@ -179,7 +194,6 @@ def shard_parameters(
             hash_size=tensor.shape[0],
             embedding_dim=tensor.shape[1],
             world_size=world_size,
-            # pyre-fixme [6]
             local_size=local_size,
         )
         shards = [
@@ -189,7 +203,7 @@ def shard_parameters(
                     local_cols[rank],
                 ],
                 shard_offsets=[local_row_offsets[rank], 0],
-                placement=f"rank:{rank}/cuda:{rank}",
+                placement=_device_placement(device, rank, local_size),
             )
             for rank in range(table_node * local_size, (table_node + 1) * local_size)
         ]
@@ -207,8 +221,9 @@ class CwParameterSharding:
     def shard_parameters(
         cls,
         param_info: ParameterInfo,
+        device: torch.device,
         world_size: int,
-        local_size: Optional[int],
+        local_size: int,
     ) -> ParameterSharding:
         sharding_option = param_info.sharding_options[0]
         tensor = param_info.param
@@ -235,7 +250,7 @@ def shard_parameters(
                     merged_sizes[i],
                 ],
                 shard_offsets=[0, offsets[i]],
-                placement=f"rank:{rank}/cuda:{rank}",
+                placement=_device_placement(device, rank, local_size),
             )
             for i, rank in enumerate(merged_ranks)
         ]
@@ -252,8 +267,9 @@ class DpParameterSharding:
     def shard_parameters(
         cls,
         param_info: ParameterInfo,
+        device: torch.device,
         world_size: int,
-        local_size: Optional[int],
+        local_size: int,
     ) -> ParameterSharding:
         sharding_option = param_info.sharding_options[0]
         return ParameterSharding(
diff --git a/distributed/planner/utils.py b/distributed/planner/utils.py
index 2a5173d71..9c14202e6 100644
--- a/distributed/planner/utils.py
+++ b/distributed/planner/utils.py
@@ -214,14 +214,16 @@ def param_sort_key(
 
 def to_plan(
     parameter_infos: List[ParameterInfo],
+    device: torch.device,
     world_size: int,
-    local_size: Optional[int],
+    local_size: int,
 ) -> ShardingPlan:
     plan = {}
     for parameter_info in parameter_infos:
         shards = plan.get(parameter_info.prefix, {})
         shards[parameter_info.name] = ParameterShardingFactory.shard_parameters(
             param_info=parameter_info,
+            device=device,
             world_size=world_size,
             local_size=local_size,
         )
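Illustration (not part of the diff above): a minimal, self-contained sketch of the placement strings this change produces. It copies the _device_placement helper exactly as added in distributed/planner/parameter_sharding.py; the rank, world_size and local_size values below are hypothetical examples.

    import torch

    def _device_placement(device: torch.device, rank: int, local_size: int) -> str:
        # Same logic as the helper added above: on CUDA, the device index is the
        # local rank (rank % local_size) rather than the global rank.
        param_device = device
        if device.type == "cuda":
            param_device = torch.device("cuda", rank % local_size)
        return f"rank:{rank}/{param_device}"

    # Hypothetical 2-host x 4-GPU job (world_size=8, local_size=4):
    print(_device_placement(torch.device("cuda"), 6, 4))  # rank:6/cuda:2
    print(_device_placement(torch.device("cpu"), 6, 4))   # rank:6/cpu
    # The previous hard-coded f"rank:{rank}/cuda:{rank}" would have produced
    # rank:6/cuda:6, an out-of-range device index on a 4-GPU host.

The rest of the diff only threads the planner's device (self._device) through to_plan and ParameterShardingFactory.shard_parameters so that each sharder can build its placements from this helper.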