diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index 608544e12..8526d0c6b 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -491,18 +491,20 @@ def shard(
     def sharding_types(self) -> List[str]:
         return [ShardingType.DATA_PARALLEL.value]

-    def compute_kernels(self, sharding_type: str, device: torch.device) -> List[str]:
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
         return [
             EmbeddingComputeKernel.BATCHED_QUANT.value,
         ]

     def storage_usage(
-        self, tensor: torch.Tensor, device: torch.device, compute_kernel: str
+        self, tensor: torch.Tensor, compute_device_type: str, compute_kernel: str
     ) -> Dict[str, int]:
         tensor_bytes = tensor.numel() * tensor.element_size() + tensor.shape[0] * 4
-        assert device.type in {"cuda", "cpu"}
+        assert compute_device_type in {"cuda", "cpu"}
         storage_map = {"cuda": ParameterStorage.HBM, "cpu": ParameterStorage.DDR}
-        return {storage_map[device.type].value: tensor_bytes}
+        return {storage_map[compute_device_type].value: tensor_bytes}

     def shardable_parameters(
         self, module: QuantEmbeddingBagCollection
diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py
index 9b5e30d07..f88f13fcf 100644
--- a/torchrec/distributed/embedding_types.py
+++ b/torchrec/distributed/embedding_types.py
@@ -174,7 +174,9 @@ def sharding_types(self) -> List[str]:

         return types

-    def compute_kernels(self, sharding_type: str, device: torch.device) -> List[str]:
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
         ret = [
             EmbeddingComputeKernel.DENSE.value,
             EmbeddingComputeKernel.BATCHED_DENSE.value,
@@ -184,7 +186,7 @@ def compute_kernels(self, sharding_type: str, device: torch.device) -> List[str]
                 EmbeddingComputeKernel.BATCHED_FUSED.value,
                 EmbeddingComputeKernel.SPARSE.value,
             ]
-            if device.type in {"cuda"}:
+            if compute_device_type in {"cuda"}:
                 ret += [
                     EmbeddingComputeKernel.BATCHED_FUSED_UVM.value,
                     EmbeddingComputeKernel.BATCHED_FUSED_UVM_CACHING.value,
@@ -196,7 +198,7 @@ def fused_params(self) -> Optional[Dict[str, Any]]:
         return self._fused_params

     def storage_usage(
-        self, tensor: torch.Tensor, device: torch.device, compute_kernel: str
+        self, tensor: torch.Tensor, compute_device_type: str, compute_kernel: str
     ) -> Dict[str, int]:
         """
         List of system resources and corresponding usage given a compute device and
@@ -207,12 +209,12 @@ def storage_usage(
             EmbeddingComputeKernel.BATCHED_FUSED_UVM.value,
             EmbeddingComputeKernel.BATCHED_FUSED_UVM_CACHING.value,
         }:
-            assert device.type in {"cuda"}
+            assert compute_device_type in {"cuda"}
             return {ParameterStorage.DDR.value: tensor_bytes}
         else:
-            assert device.type in {"cuda", "cpu"}
+            assert compute_device_type in {"cuda", "cpu"}
             storage_map = {"cuda": ParameterStorage.HBM, "cpu": ParameterStorage.DDR}
             return {
-                storage_map[device.type].value: tensor.element_size()
+                storage_map[compute_device_type].value: tensor.element_size()
                 * tensor.nelement()
             }
diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py
index 6e04694e0..178571fd3 100644
--- a/torchrec/distributed/model_parallel.py
+++ b/torchrec/distributed/model_parallel.py
@@ -95,7 +95,7 @@ def __init__(

         # 2. Call ShardingPlanner.collective_plan passing all found modules and corresponding sharders.
         if plan is None:
-            planner = EmbeddingShardingPlanner(self._env.world_size, self.device)
+            planner = EmbeddingShardingPlanner(self._env.world_size, self.device.type)
             pg = self._env.process_group
             if pg is not None:
                 plan = planner.collective_plan(module, sharders, pg)
diff --git a/torchrec/distributed/planner/embedding_planner.py b/torchrec/distributed/planner/embedding_planner.py
index b27367728..9ba6bd95f 100644
--- a/torchrec/distributed/planner/embedding_planner.py
+++ b/torchrec/distributed/planner/embedding_planner.py
@@ -50,7 +50,7 @@ class EmbeddingShardingPlanner(ShardingPlanner):
     def __init__(
         self,
         world_size: int,
-        device: torch.device,
+        compute_device_type: str = "cuda",
         hints: Optional[Dict[str, ParameterHints]] = None,
         input_stats: Optional[Dict[str, ParameterInputStats]] = None,
         storage: Optional[Dict[str, int]] = None,
@@ -62,7 +62,7 @@ def __init__(
         self._input_stats: Dict[str, ParameterInputStats] = (
             input_stats if input_stats else {}
         )
-        self._device = device
+        self._compute_device_type = compute_device_type

         if cost_functions is None:
             self._cost_functions: List[Callable[[CostInput], int]] = [
@@ -71,7 +71,9 @@ def __init__(
         else:
             self._cost_functions = cost_functions

-        self._topology: Topology = get_topology(world_size, device, storage)
+        self._topology: Topology = get_topology(
+            world_size, compute_device_type, storage
+        )
         self._counter: int = 1

     def collective_plan(
@@ -131,7 +133,7 @@ def plan(

         sharding_plan = to_plan(
             param_infos,
-            self._device,
+            self._compute_device_type,
             self._world_size,
             self._local_size,
         )
@@ -419,11 +421,14 @@ def _get_param_infos(
                     name, param, sharding_type
                 )
                 for compute_kernel in self._filter_compute_kernels(
-                    name, sharder.compute_kernels(sharding_type, self._device)
+                    name,
+                    sharder.compute_kernels(
+                        sharding_type, self._compute_device_type
+                    ),
                 ):
                     cost_input = CostInput(
                         param=param,
-                        device=self._device,
+                        compute_device_type=self._compute_device_type,
                         compute_kernel=compute_kernel,
                         sharding_type=sharding_type,
                         input_stats=self._input_stats.get(name, None),
@@ -440,7 +445,7 @@ def _get_param_infos(
                         sharding_type=sharding_type,
                         compute_kernel=compute_kernel,
                         storage_usage=sharder.storage_usage(
-                            param, self._device, compute_kernel
+                            param, self._compute_device_type, compute_kernel
                         ),
                         _num_col_wise_shards=num_col_wise_shards,
                         col_wise_shard_dim=shard_size,
diff --git a/torchrec/distributed/planner/parameter_sharding.py b/torchrec/distributed/planner/parameter_sharding.py
index 4d35287b4..f7f53ad04 100644
--- a/torchrec/distributed/planner/parameter_sharding.py
+++ b/torchrec/distributed/planner/parameter_sharding.py
@@ -64,12 +64,12 @@ def _rw_shard_table_rows(hash_size: int, world_size: int) -> Tuple[List[int], in


 def _device_placement(
-    device: torch.device,
+    compute_device_type: str,
     rank: int,
     local_size: int,
 ) -> str:
-    param_device = device
-    if device.type == "cuda":
+    param_device = torch.device("cpu")
+    if compute_device_type == "cuda":
         param_device = torch.device("cuda", rank % local_size)
     return f"rank:{rank}/{param_device}"

@@ -78,7 +78,7 @@ class ParameterShardingFactory(abc.ABC):
     @staticmethod
     def shard_parameters(
         param_info: ParameterInfo,
-        device: torch.device,
+        compute_device_type: str,
         world_size: int,
         local_size: int,
     ) -> ParameterSharding:
@@ -86,23 +86,23 @@ def shard_parameters(
         sharding_type = sharding_option.sharding_type
         if sharding_type == ShardingType.TABLE_WISE.value:
             parameter_sharding = TwParameterSharding.shard_parameters(
-                param_info, device, world_size, local_size
+                param_info, compute_device_type, world_size, local_size
             )
         elif sharding_type == ShardingType.ROW_WISE.value:
             parameter_sharding = RwParameterSharding.shard_parameters(
-                param_info, device, world_size, local_size
+                param_info, compute_device_type, world_size, local_size
             )
         elif sharding_type == ShardingType.TABLE_ROW_WISE.value:
             parameter_sharding = TwRwParameterSharding.shard_parameters(
-                param_info, device, world_size, local_size
+                param_info, compute_device_type, world_size, local_size
             )
         elif sharding_type == ShardingType.COLUMN_WISE.value:
             parameter_sharding = CwParameterSharding.shard_parameters(
-                param_info, device, world_size, local_size
+                param_info, compute_device_type, world_size, local_size
             )
         elif sharding_type == ShardingType.DATA_PARALLEL.value:
             parameter_sharding = DpParameterSharding.shard_parameters(
-                param_info, device, world_size, local_size
+                param_info, compute_device_type, world_size, local_size
             )
         else:
             raise ValueError(
@@ -116,7 +116,7 @@ class TwParameterSharding:
     def shard_parameters(
         cls,
         param_info: ParameterInfo,
-        device: torch.device,
+        compute_device_type: str,
         world_size: int,
         local_size: int,
     ) -> ParameterSharding:
@@ -131,7 +131,7 @@ def shard_parameters(
                     tensor.shape[1],
                 ],
                 shard_offsets=[0, 0],
-                placement=_device_placement(device, rank, local_size),
+                placement=_device_placement(compute_device_type, rank, local_size),
             )
         ]
         return ParameterSharding(
@@ -147,7 +147,7 @@ class RwParameterSharding:
     def shard_parameters(
         cls,
         param_info: ParameterInfo,
-        device: torch.device,
+        compute_device_type: str,
         world_size: int,
         local_size: int,
     ) -> ParameterSharding:
@@ -163,7 +163,7 @@ def shard_parameters(
                     tensor.shape[1],
                 ],
                 shard_offsets=[block_size * min(rank, last_rank), 0],
-                placement=_device_placement(device, rank, local_size),
+                placement=_device_placement(compute_device_type, rank, local_size),
             )
             for rank in range(world_size)
         ]
@@ -180,7 +180,7 @@ class TwRwParameterSharding:
     def shard_parameters(
         cls,
         param_info: ParameterInfo,
-        device: torch.device,
+        compute_device_type: str,
         world_size: int,
         local_size: int,
     ) -> ParameterSharding:
@@ -203,7 +203,7 @@ def shard_parameters(
                     local_cols[rank],
                 ],
                 shard_offsets=[local_row_offsets[rank], 0],
-                placement=_device_placement(device, rank, local_size),
+                placement=_device_placement(compute_device_type, rank, local_size),
             )
             for rank in range(table_node * local_size, (table_node + 1) * local_size)
         ]
@@ -221,7 +221,7 @@ class CwParameterSharding:
     def shard_parameters(
         cls,
         param_info: ParameterInfo,
-        device: torch.device,
+        compute_device_type: str,
         world_size: int,
         local_size: int,
     ) -> ParameterSharding:
@@ -250,7 +250,7 @@ def shard_parameters(
                     merged_sizes[i],
                 ],
                 shard_offsets=[0, offsets[i]],
-                placement=_device_placement(device, rank, local_size),
+                placement=_device_placement(compute_device_type, rank, local_size),
             )
             for i, rank in enumerate(merged_ranks)
         ]
@@ -267,7 +267,7 @@ class DpParameterSharding:
     def shard_parameters(
         cls,
         param_info: ParameterInfo,
-        device: torch.device,
+        compute_device_type: str,
         world_size: int,
         local_size: int,
     ) -> ParameterSharding:
diff --git a/torchrec/distributed/planner/tests/test_embedding_planner.py b/torchrec/distributed/planner/tests/test_embedding_planner.py
index fe5d2b29d..4adb7d822 100644
--- a/torchrec/distributed/planner/tests/test_embedding_planner.py
+++ b/torchrec/distributed/planner/tests/test_embedding_planner.py
@@ -4,7 +4,7 @@
 from typing import List
 from unittest.mock import MagicMock, patch, call

-import torch
+from torch import distributed as dist
 from torch.distributed._sharding_spec import ShardMetadata, EnumerableShardingSpec
 from torchrec.distributed.embedding import EmbeddingBagCollectionSharder
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
@@ -29,7 +29,9 @@ def sharding_types(self) -> List[str]:
     Restricts to single impl.
     """

-    def compute_kernels(self, sharding_type: str, device: torch.device) -> List[str]:
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
         return [EmbeddingComputeKernel.DENSE.value]


@@ -45,7 +47,9 @@ def sharding_types(self) -> List[str]:
     Restricts to single impl.
     """

-    def compute_kernels(self, sharding_type: str, device: torch.device) -> List[str]:
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
         return [EmbeddingComputeKernel.DENSE.value]


@@ -58,7 +62,9 @@ def sharding_types(self) -> List[str]:
     Restricts to single impl.
     """

-    def compute_kernels(self, sharding_type: str, device: torch.device) -> List[str]:
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
         return [EmbeddingComputeKernel.DENSE.value]


@@ -74,7 +80,9 @@ def sharding_types(self) -> List[str]:
     Restricts to single impl.
     """

-    def compute_kernels(self, sharding_type: str, device: torch.device) -> List[str]:
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
         return [EmbeddingComputeKernel.DENSE.value]


@@ -91,14 +99,16 @@ def sharding_types(self) -> List[str]:
     Restricts to single impl.
     """

-    def compute_kernels(self, sharding_type: str, device: torch.device) -> List[str]:
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
         return [EmbeddingComputeKernel.DENSE.value]


 class TestEmbeddingPlanner(unittest.TestCase):
     def setUp(self) -> None:
         # Mocks
-        self.device = torch.device("cuda:0")
+        self.compute_device_type = "cuda"

     @patch("torchrec.distributed.planner.embedding_planner.logger", create=True)
     def test_allocation_planner_balanced(self, mock_logger: MagicMock) -> None:
@@ -188,7 +198,9 @@ def test_allocation_planner_balanced(self, mock_logger: MagicMock) -> None:
         model = TestSparseNN(tables=tables, weighted_tables=[])
         world_size = 2
         planner = EmbeddingShardingPlanner(
-            world_size=world_size, device=self.device, storage=storage
+            world_size=world_size,
+            compute_device_type=self.compute_device_type,
+            storage=storage,
         )
         sharders = [TWSharder()]

@@ -303,7 +315,9 @@ def test_allocation_planner_one_big_rest_small(

         world_size = 2
         planner = EmbeddingShardingPlanner(
-            world_size=world_size, device=self.device, storage=storage
+            world_size=world_size,
+            compute_device_type=self.compute_device_type,
+            storage=storage,
         )
         sharders = [DPTWSharder()]
         # pyre-ignore [6]
@@ -396,7 +410,7 @@ def test_allocation_planner_two_big_rest_small(
         world_size = 2
         planner = EmbeddingShardingPlanner(
             world_size=world_size,
-            device=self.device,
+            compute_device_type=self.compute_device_type,
             # pyre-fixme[6]: Expected `Optional[typing.Dict[str, int]]` for 3rd
             # param but got `Dict[str, float]`.
             storage=storage,
@@ -538,7 +552,7 @@ def test_allocation_planner_rw_two_big_rest_small(
         world_size = 4
         planner = EmbeddingShardingPlanner(
             world_size=world_size,
-            device=self.device,
+            compute_device_type=self.compute_device_type,
             # pyre-fixme[6]: Expected `Optional[typing.Dict[str, int]]` for 3rd
             # param but got `Dict[str, float]`.
             storage=storage,
@@ -608,7 +622,7 @@ def test_allocation_planner_cw_balanced(self, mock_logger: MagicMock) -> None:
         world_size = 2
         planner = EmbeddingShardingPlanner(
             world_size=world_size,
-            device=self.device,
+            compute_device_type=self.compute_device_type,
             storage=storage,
             hints={
                 "table_0": ParameterHints(
@@ -724,7 +738,7 @@ def test_allocation_planner_cw_two_big_rest_small_with_residual(
         world_size = 4
         planner = EmbeddingShardingPlanner(
             world_size=world_size,
-            device=self.device,
+            compute_device_type=self.compute_device_type,
             # pyre-fixme[6]: Expected `Optional[typing.Dict[str, int]]` for 3rd
             # param but got `Dict[str, float]`.
             storage=storage,
diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py
index df3ca2e7f..ba444cc7c 100644
--- a/torchrec/distributed/planner/types.py
+++ b/torchrec/distributed/planner/types.py
@@ -103,7 +103,7 @@ def __lt__(self, other: "ShardingOption") -> bool:
 @dataclass
 class CostInput:
     param: torch.Tensor
-    device: torch.device
+    compute_device_type: str
     compute_kernel: str
     sharding_type: str
     input_stats: Optional[ParameterInputStats]
diff --git a/torchrec/distributed/planner/utils.py b/torchrec/distributed/planner/utils.py
index 56ebaf8e4..cfe06c250 100644
--- a/torchrec/distributed/planner/utils.py
+++ b/torchrec/distributed/planner/utils.py
@@ -210,7 +210,7 @@ def param_sort_key(

 def to_plan(
     parameter_infos: List[ParameterInfo],
-    device: torch.device,
+    compute_device_type: str,
     world_size: int,
     local_size: int,
 ) -> ShardingPlan:
@@ -219,7 +219,7 @@ def to_plan(
         shards = plan.get(parameter_info.prefix, {})
         shards[parameter_info.name] = ParameterShardingFactory.shard_parameters(
             param_info=parameter_info,
-            device=device,
+            compute_device_type=compute_device_type,
             world_size=world_size,
             local_size=local_size,
         )
@@ -253,13 +253,18 @@ def _get_storage(

 def get_topology(
     world_size: int,
-    device: torch.device,
+    compute_device_type: str,
     storage_in_gb: Optional[Dict[str, int]],
 ) -> Topology:
     devices_per_host = get_local_size(world_size)
     num_hosts = get_num_groups(world_size)
-    compute_device = device.type
-    storage = _get_storage(device, storage_in_gb)
+    compute_device = compute_device_type
+    sample_device = (
+        torch.device("cuda", 0)
+        if compute_device_type == "cuda"
+        else torch.device("cpu")
+    )
+    storage = _get_storage(sample_device, storage_in_gb)
     topology = Topology(
         hosts=[
             HostInfo(
diff --git a/torchrec/distributed/tests/test_model.py b/torchrec/distributed/tests/test_model.py
index f381e4ee0..8c1232698 100644
--- a/torchrec/distributed/tests/test_model.py
+++ b/torchrec/distributed/tests/test_model.py
@@ -447,7 +447,9 @@ def sharding_types(self) -> List[str]:
     Restricts to single impl.
     """

-    def compute_kernels(self, sharding_type: str, device: torch.device) -> List[str]:
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
         return [self._kernel_type]

     @property
@@ -472,7 +474,9 @@ def sharding_types(self) -> List[str]:
     Restricts to single impl.
     """

-    def compute_kernels(self, sharding_type: str, device: torch.device) -> List[str]:
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
         return [self._kernel_type]

     @property
diff --git a/torchrec/distributed/tests/test_model_parallel_base.py b/torchrec/distributed/tests/test_model_parallel_base.py
index ae42ed300..58974f937 100644
--- a/torchrec/distributed/tests/test_model_parallel_base.py
+++ b/torchrec/distributed/tests/test_model_parallel_base.py
@@ -90,7 +90,7 @@ def _test_sharding_single_rank(
             sparse_device=torch.device("meta"),
         )

-        planner = EmbeddingShardingPlanner(world_size, device, hints)
+        planner = EmbeddingShardingPlanner(world_size, device.type, hints)
         plan = planner.collective_plan(local_model, sharders, pg)

         local_model = DistributedModelParallel(
diff --git a/torchrec/distributed/tests/test_quant_model_parallel.py b/torchrec/distributed/tests/test_quant_model_parallel.py
index c086c5ac5..6c300b1b6 100644
--- a/torchrec/distributed/tests/test_quant_model_parallel.py
+++ b/torchrec/distributed/tests/test_quant_model_parallel.py
@@ -41,7 +41,9 @@ def __init__(self, sharding_type: str, kernel_type: str) -> None:
     def sharding_types(self) -> List[str]:
         return [self._sharding_type]

-    def compute_kernels(self, sharding_type: str, device: torch.device) -> List[str]:
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
         return [self._kernel_type]


diff --git a/torchrec/distributed/tests/test_train_pipeline.py b/torchrec/distributed/tests/test_train_pipeline.py
index fa0151442..2acf6c53c 100644
--- a/torchrec/distributed/tests/test_train_pipeline.py
+++ b/torchrec/distributed/tests/test_train_pipeline.py
@@ -69,7 +69,9 @@ def sharding_types(self) -> List[str]:
             ShardingType.TABLE_WISE.value,
         ]

-    def compute_kernels(self, sharding_type: str, device: torch.device) -> List[str]:
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
         return [EmbeddingComputeKernel.DENSE.value]


diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py
index 55aee5b42..db1d868a8 100644
--- a/torchrec/distributed/types.py
+++ b/torchrec/distributed/types.py
@@ -296,11 +296,11 @@ class ShardingEnv:
     """

     def __init__(
-        self, world_size: int, rank: int, pg: Optional[dist.ProcessGroup]
+        self, world_size: int, rank: int, pg: Optional[dist.ProcessGroup] = None
    ) -> None:
         self.world_size = world_size
         self.rank = rank
-        self.process_group = pg
+        self.process_group: Optional[dist.ProcessGroup] = pg

     @classmethod
     def from_process_group(cls, pg: dist.ProcessGroup) -> "ShardingEnv":
@@ -439,7 +439,9 @@ def sharding_types(self) -> List[str]:
         """
         return [ShardingType.DATA_PARALLEL.value]

-    def compute_kernels(self, sharding_type: str, device: torch.device) -> List[str]:
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
         """
         List of supported compute kernels for a given sharding_type and compute device.
         """
@@ -447,17 +449,18 @@ def compute_kernels(self, sharding_type: str, device: torch.device) -> List[str]
         return [ComputeKernel.DEFAULT.value]

     def storage_usage(
-        self, tensor: torch.Tensor, device: torch.device, compute_kernel: str
+        self, tensor: torch.Tensor, compute_device_type: str, compute_kernel: str
     ) -> Dict[str, int]:
         """
         List of system resources and corresponding usage given a compute device and
         compute kernel
         """

-        assert device.type in {"cuda", "cpu"}
+        assert compute_device_type in {"cuda", "cpu"}
         storage_map = {"cuda": ParameterStorage.HBM, "cpu": ParameterStorage.DDR}
         return {
-            storage_map[device.type].value: tensor.element_size() * tensor.nelement()
+            storage_map[compute_device_type].value: tensor.element_size()
+            * tensor.nelement()
         }
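
Reviewer note (not part of the diff): below is a minimal sketch of how a call site adapts to the string-based compute_device_type argument, assuming the constructor and sharder signatures shown in the hunks above; import paths follow the files touched in this change. The table shape and the sharding type are illustrative only, and constructing the planner with "cuda" and no explicit storage dict assumes a CUDA-capable host, since get_topology may probe GPU memory.

import torch

from torchrec.distributed.embedding import EmbeddingBagCollectionSharder
from torchrec.distributed.planner.embedding_planner import EmbeddingShardingPlanner
from torchrec.distributed.types import ShardingType

# Before this change the planner took a torch.device:
#     EmbeddingShardingPlanner(world_size=2, device=torch.device("cuda:0"))
# Now callers pass only the device type string; "cuda" is the new default.
planner = EmbeddingShardingPlanner(world_size=2, compute_device_type="cuda")

# Sharder hooks follow the same convention and take a "cuda"/"cpu" string.
sharder = EmbeddingBagCollectionSharder()
kernels = sharder.compute_kernels(ShardingType.TABLE_WISE.value, "cuda")

weight = torch.empty(100, 64)  # stand-in for an embedding table weight
usage = sharder.storage_usage(weight, "cuda", kernels[0])
# storage_usage maps the storage tier for "cuda" (HBM) to the tensor's size in bytes.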