Skip to content

Commit 5aaabb9

Browse files
iamzainhuda and facebook-github-bot
authored and committed
Replacing ShardedTensor with DTensor for RW sharding
Summary: **This is the first part of migrating TorchRec state dict checkpointing from ShardedTensor to DTensor. It sets up the necessary infra to support additional sharding schemes. The general approach is to keep the ShardedTensor paths and remove them once all sharding types are supported on DTensor. This includes ShardingPlan and ShardedTensor dataclasses such as ShardedTensorMetadata. Those will be migrated in a separate diff with ParameterSharding.** NOTE: This version of LocalShardsWrapper does not support empty shards; that support is added in the next diff, which enables CW (D57063512). **This diff includes:** + LocalShardsWrapper, a torch.Tensor subclass to be used with DTensor + Changes in TorchRec state_dict load and creation to use DTensor for the Row Wise path in both EmbeddingCollection and EmbeddingBagCollection + Changes to DCP to support LocalShardsWrapper for saving and reading (WriteItems and ReadItems) + Added DTensor paths to callsites where ShardedTensors are expected. **LocalShardsWrapper supports the following torch ops:** + torch.ops._c10d_functional.all_gather_into_tensor.default + aten._to_copy.default + aten.view.default + aten.equal.default + aten.detach.default More ops can be added as use cases require. See https://docs.google.com/document/d/16Ptl50mGFJW2cljdF2HQ6FwsiA0scwbAbjx_4dhabJw/edit?usp=drivesdk for more info regarding design and approach. Reviewed By: XilunWu Differential Revision: D54375878
1 parent 7b73952 commit 5aaabb9

File tree

11 files changed

+341
-76
lines changed

11 files changed

+341
-76
lines changed

torchrec/distributed/composable/tests/test_embedding.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import torch
1515
import torch.nn as nn
1616
from hypothesis import given, settings, Verbosity
17+
from torch.distributed._tensor.api import DTensor
1718
from torch.distributed.optim import (
1819
_apply_optimizer_in_backward as apply_optimizer_in_backward,
1920
)
@@ -177,6 +178,8 @@ def _test_sharding( # noqa C901
177178
)
178179
if isinstance(sharded_state, ShardedTensor):
179180
sharded_state.gather(out=sharded_param)
181+
elif isinstance(sharded_state, DTensor):
182+
sharded_param = sharded_state.full_tensor()
180183
else:
181184
sharded_param = sharded_state
182185

torchrec/distributed/composable/tests/test_embeddingbag.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import torch.nn as nn
1919

2020
from hypothesis import assume, given, settings, Verbosity
21+
from torch.distributed._tensor.api import DTensor
2122
from torch.distributed.optim import (
2223
_apply_optimizer_in_backward as apply_optimizer_in_backward,
2324
)
@@ -238,7 +239,11 @@ def _test_sharding( # noqa C901
238239
if ctx.rank == 0
239240
else None
240241
)
241-
sharded_state.gather(out=out)
242+
if isinstance(sharded_state, DTensor):
243+
out = sharded_state.full_tensor()
244+
else:
245+
sharded_state.gather(out=out)
246+
242247
if ctx.rank == 0:
243248
torch.testing.assert_close(
244249
unsharded_state,

torchrec/distributed/composable/tests/test_fsdp.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from torch import nn
1717
from torch.distributed._composable import fully_shard
1818
from torch.distributed._shard.sharded_tensor import ShardedTensor
19+
from torch.distributed._tensor import DTensor
1920

2021
from torch.distributed.checkpoint import (
2122
FileSystemReader,
@@ -193,6 +194,10 @@ def _run( # noqa
193194
if not p.local_shards():
194195
continue
195196
p = p.local_tensor()
197+
if isinstance(p, DTensor):
198+
if not p.to_local().local_shards():
199+
continue
200+
p = p.to_local().local_shards()[0]
196201
p_sum += p.sum()
197202
p.zero_()
198203
assert p.sum() == 0
@@ -205,6 +210,10 @@ def _run( # noqa
205210
if not t.local_shards():
206211
continue
207212
t = t.local_tensor()
213+
if isinstance(t, DTensor):
214+
if not t.to_local().local_shards(): # pyre-ignore[16]
215+
continue
216+
t = t.to_local().local_shards()[0]
208217
o_sum += t.sum()
209218
t.zero_()
210219
assert t.sum() == 0
@@ -228,6 +237,10 @@ def _run( # noqa
228237
continue
229238
p = p.local_tensor()
230239
p_sum_loaded += p.sum()
240+
if isinstance(p, DTensor):
241+
if not p.to_local().local_shards():
242+
continue
243+
p = p.to_local().local_shards()[0]
231244
assert p_sum.allclose(p_sum_loaded)
232245

233246
o_sum_loaded = torch.zeros(1, device=ctx.device)
@@ -239,6 +252,10 @@ def _run( # noqa
239252
if not t.local_shards():
240253
continue
241254
t = t.local_tensor()
255+
if isinstance(t, DTensor):
256+
if not t.to_local().local_shards():
257+
continue
258+
t = t.to_local().local_shards()[0]
242259
o_sum_loaded += t.sum()
243260
assert o_sum.allclose(o_sum_loaded)
244261

torchrec/distributed/embedding.py

Lines changed: 108 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import torch
1919
from torch import distributed as dist, nn
2020
from torch.autograd.profiler import record_function
21+
from torch.distributed._tensor import DTensor
22+
from torch.distributed._tensor._shards_wrapper import LocalShardsWrapper
2123
from torch.nn.parallel import DistributedDataParallel
2224
from torchrec.distributed.embedding_sharding import (
2325
EmbeddingSharding,
@@ -589,18 +591,20 @@ def _pre_load_state_dict_hook(
589591
) -> None:
590592
"""
591593
Modify the destination state_dict for model parallel
592-
to transform from ShardedTensors into tensors
594+
to transform from ShardedTensors/DTensors into tensors
593595
"""
594-
for (
595-
table_name,
596-
model_shards,
597-
) in self._model_parallel_name_to_local_shards.items():
596+
for table_name in self._model_parallel_name_to_local_shards.keys():
598597
key = f"{prefix}embeddings.{table_name}.weight"
599-
598+
# gather model shards from both DTensor and ShardedTensor maps
599+
model_shards_sharded_tensor = self._model_parallel_name_to_local_shards[
600+
table_name
601+
]
602+
model_shards_dtensor = self._model_parallel_name_to_shards_wrapper[
603+
table_name
604+
]
600605
# If state_dict[key] is already a ShardedTensor, use its local shards
601606
if isinstance(state_dict[key], ShardedTensor):
602607
local_shards = state_dict[key].local_shards()
603-
# If no local shards, create an empty tensor
604608
if len(local_shards) == 0:
605609
state_dict[key] = torch.empty(0)
606610
else:
@@ -612,27 +616,57 @@ def _pre_load_state_dict_hook(
612616
).view(-1, dim)
613617
else:
614618
state_dict[key] = local_shards[0].tensor.view(-1, dim)
615-
else:
619+
elif isinstance(state_dict[key], DTensor):
620+
shards_wrapper = state_dict[key].to_local()
621+
local_shards = shards_wrapper.local_shards()
622+
dim = shards_wrapper.local_sizes()[0][1]
623+
if len(local_shards) == 0:
624+
state_dict[key] = torch.empty(0)
625+
elif len(local_shards) > 1:
626+
# TODO - add multiple shards on rank support
627+
raise RuntimeError(
628+
f"Multiple shards on rank is not supported for DTensor yet, got {len(local_shards)}"
629+
)
630+
else:
631+
state_dict[key] = local_shards[0].view(-1, dim)
632+
elif isinstance(state_dict[key], torch.Tensor):
616633
local_shards = []
617-
for shard in model_shards:
618-
# Extract shard size and offsets for splicing
619-
shard_sizes = shard.metadata.shard_sizes
620-
shard_offsets = shard.metadata.shard_offsets
621-
622-
# Prepare tensor by splicing and placing on appropriate device
623-
spliced_tensor = state_dict[key][
624-
shard_offsets[0] : shard_offsets[0] + shard_sizes[0],
625-
shard_offsets[1] : shard_offsets[1] + shard_sizes[1],
626-
].to(shard.tensor.get_device())
627-
628-
# Append spliced tensor into local shards
629-
local_shards.append(spliced_tensor)
630-
634+
if model_shards_sharded_tensor:
635+
# splice according to sharded tensor metadata
636+
for shard in model_shards_sharded_tensor:
637+
# Extract shard size and offsets for splicing
638+
shard_size = shard.metadata.shard_sizes
639+
shard_offset = shard.metadata.shard_offsets
640+
641+
# Prepare tensor by splicing and placing on appropriate device
642+
spliced_tensor = state_dict[key][
643+
shard_offset[0] : shard_offset[0] + shard_size[0],
644+
shard_offset[1] : shard_offset[1] + shard_size[1],
645+
]
646+
647+
# Append spliced tensor into local shards
648+
local_shards.append(spliced_tensor)
649+
elif model_shards_dtensor:
650+
# splice according to dtensor metadata
651+
for tensor, shard_offset in zip(
652+
model_shards_dtensor["local_tensors"],
653+
model_shards_dtensor["local_offsets"],
654+
):
655+
shard_size = tensor.size()
656+
spliced_tensor = state_dict[key][
657+
shard_offset[0] : shard_offset[0] + shard_size[0],
658+
shard_offset[1] : shard_offset[1] + shard_size[1],
659+
]
660+
local_shards.append(spliced_tensor)
631661
state_dict[key] = (
632662
torch.empty(0)
633663
if not local_shards
634664
else torch.cat(local_shards, dim=0)
635665
)
666+
else:
667+
raise RuntimeError(
668+
f"Unexpected state_dict key type {type(state_dict[key])} found for {key}"
669+
)
636670

637671
for lookup in self._lookups:
638672
while isinstance(lookup, DistributedDataParallel):
@@ -649,7 +683,9 @@ def _initialize_torch_state(self) -> None: # noqa
649683
for table_name in self._table_names:
650684
self.embeddings[table_name] = nn.Module()
651685
self._model_parallel_name_to_local_shards = OrderedDict()
686+
self._model_parallel_name_to_shards_wrapper = OrderedDict()
652687
self._model_parallel_name_to_sharded_tensor = OrderedDict()
688+
self._model_parallel_name_to_dtensor = OrderedDict()
653689
model_parallel_name_to_compute_kernel: Dict[str, str] = {}
654690
for (
655691
table_name,
@@ -658,6 +694,9 @@ def _initialize_torch_state(self) -> None: # noqa
658694
if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
659695
continue
660696
self._model_parallel_name_to_local_shards[table_name] = []
697+
self._model_parallel_name_to_shards_wrapper[table_name] = OrderedDict(
698+
[("local_tensors", []), ("local_offsets", [])]
699+
)
661700
model_parallel_name_to_compute_kernel[table_name] = (
662701
parameter_sharding.compute_kernel
663702
)
@@ -679,18 +718,29 @@ def _initialize_torch_state(self) -> None: # noqa
679718
# save local_shards for transforming MP params to shardedTensor
680719
for key, v in lookup.state_dict().items():
681720
table_name = key[: -len(".weight")]
682-
self._model_parallel_name_to_local_shards[table_name].extend(
683-
v.local_shards()
684-
)
721+
if isinstance(v, DTensor):
722+
shards_wrapper = self._model_parallel_name_to_shards_wrapper[
723+
table_name
724+
]
725+
local_shards_wrapper = v._local_tensor
726+
shards_wrapper["local_tensors"].extend(local_shards_wrapper.local_shards()) # pyre-ignore[16]
727+
shards_wrapper["local_offsets"].extend(local_shards_wrapper.local_offsets()) # pyre-ignore[16]
728+
shards_wrapper["global_size"] = v.size()
729+
shards_wrapper["global_stride"] = v.stride()
730+
shards_wrapper["placements"] = v.placements
731+
elif isinstance(v, ShardedTensor):
732+
self._model_parallel_name_to_local_shards[table_name].extend(
733+
v.local_shards()
734+
)
685735
for (
686736
table_name,
687737
tbe_slice,
688738
) in lookup.named_parameters_by_table():
689739
self.embeddings[table_name].register_parameter("weight", tbe_slice)
690-
for (
691-
table_name,
692-
local_shards,
693-
) in self._model_parallel_name_to_local_shards.items():
740+
for table_name in self._model_parallel_name_to_local_shards.keys():
741+
local_shards = self._model_parallel_name_to_local_shards[table_name]
742+
shards_wrapper_map = self._model_parallel_name_to_shards_wrapper[table_name]
743+
694744
# for shards that don't exist on this rank, register with empty tensor
695745
if not hasattr(self.embeddings[table_name], "weight"):
696746
self.embeddings[table_name].register_parameter(
@@ -703,18 +753,34 @@ def _initialize_torch_state(self) -> None: # noqa
703753
self.embeddings[table_name].weight._in_backward_optimizers = [
704754
EmptyFusedOptimizer()
705755
]
756+
706757
if model_parallel_name_to_compute_kernel[table_name] in {
707758
EmbeddingComputeKernel.KEY_VALUE.value
708759
}:
709760
continue
710-
# created ShardedTensors once in init, use in post_state_dict_hook
711-
self._model_parallel_name_to_sharded_tensor[table_name] = (
712-
ShardedTensor._init_from_local_shards(
713-
local_shards,
714-
self._name_to_table_size[table_name],
715-
process_group=self._env.process_group,
761+
762+
if shards_wrapper_map["local_tensors"]:
763+
self._model_parallel_name_to_dtensor[table_name] = DTensor.from_local(
764+
local_tensor=LocalShardsWrapper(
765+
local_shards=shards_wrapper_map["local_tensors"],
766+
local_offsets=shards_wrapper_map["local_offsets"],
767+
),
768+
device_mesh=self._env.device_mesh,
769+
placements=shards_wrapper_map["placements"],
770+
shape=shards_wrapper_map["global_size"],
771+
stride=shards_wrapper_map["global_stride"],
772+
run_check=False,
773+
)
774+
else:
775+
# if DTensors for table do not exist, create ShardedTensor
776+
# created ShardedTensors once in init, use in post_state_dict_hook
777+
self._model_parallel_name_to_sharded_tensor[table_name] = (
778+
ShardedTensor._init_from_local_shards(
779+
local_shards,
780+
self._name_to_table_size[table_name],
781+
process_group=self._env.process_group,
782+
)
716783
)
717-
)
718784

719785
def post_state_dict_hook(
720786
module: ShardedEmbeddingCollection,
@@ -729,6 +795,12 @@ def post_state_dict_hook(
729795
) in module._model_parallel_name_to_sharded_tensor.items():
730796
destination_key = f"{prefix}embeddings.{table_name}.weight"
731797
destination[destination_key] = sharded_t
798+
for (
799+
table_name,
800+
d_tensor,
801+
) in module._model_parallel_name_to_dtensor.items():
802+
destination_key = f"{prefix}embeddings.{table_name}.weight"
803+
destination[destination_key] = d_tensor
732804

733805
self.register_state_dict_pre_hook(self._pre_state_dict_hook)
734806
self._register_state_dict_hook(post_state_dict_hook)

torchrec/distributed/embedding_kernel.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
import torch
1616
import torch.distributed as dist
1717
from torch import nn
18+
from torch.distributed._tensor import DTensor
19+
from torch.distributed._tensor._shards_wrapper import LocalShardsWrapper
1820
from torchrec.distributed.embedding_types import (
21+
DTensorMetadata,
1922
EmbeddingComputeKernel,
2023
GroupedEmbeddingConfig,
2124
ShardedEmbeddingTable,
@@ -73,6 +76,8 @@ def get_state_dict(
7376
"""
7477
key_to_local_shards: Dict[str, List[Shard]] = defaultdict(list)
7578
key_to_global_metadata: Dict[str, ShardedTensorMetadata] = {}
79+
key_to_dtensor_metadata: Dict[str, DTensorMetadata] = {}
80+
key_to_local_tensor_shards: Dict[str, List[...]] = defaultdict(list)
7681

7782
def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
7883
return prefix + f"{embedding_table.name}.weight"
@@ -98,7 +103,16 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
98103
if qscale is not None:
99104
assert embedding_table.local_cols == param.size(1) # pyre-ignore[16]
100105

101-
if embedding_table.global_metadata is not None and pg is not None:
106+
if embedding_table.dtensor_metadata is not None and pg is not None:
107+
# DTensor path
108+
key_to_dtensor_metadata[key] = embedding_table.dtensor_metadata
109+
key_to_local_tensor_shards[key].append(
110+
[
111+
param,
112+
embedding_table.local_metadata.shard_offsets, # pyre-ignore[16]
113+
]
114+
)
115+
elif embedding_table.global_metadata is not None and pg is not None:
102116
# set additional field of sharded tensor based on local tensor properties
103117
embedding_table.global_metadata.tensor_properties.dtype = (
104118
param.dtype # pyre-ignore[16]
@@ -133,5 +147,24 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
133147
process_group=pg,
134148
)
135149
)
136-
150+
# DTensor path
151+
for key in key_to_local_tensor_shards:
152+
dtensor_metadata = key_to_dtensor_metadata[key]
153+
destination[key] = DTensor.from_local(
154+
local_tensor=LocalShardsWrapper(
155+
local_shards=[
156+
tensor_shards[0] # pyre-ignore[16]
157+
for tensor_shards in key_to_local_tensor_shards[key]
158+
],
159+
local_offsets=[
160+
tensor_shards[1]
161+
for tensor_shards in key_to_local_tensor_shards[key]
162+
],
163+
),
164+
device_mesh=dtensor_metadata.mesh,
165+
placements=dtensor_metadata.placements,
166+
shape=torch.Size(dtensor_metadata.size), # pyre-ignore[6]
167+
stride=dtensor_metadata.stride,
168+
run_check=False,
169+
)
137170
return destination

0 commit comments

Comments
 (0)