Skip to content

Commit 96feeb1

Browse files
PaulZhang12 authored and facebook-github-bot committed
Fixing device issues (#1880)
Summary: Pull Request resolved: #1880 Fixing device issues with sharding TorchRec inference modules on meta device Reviewed By: IvanKobzarev Differential Revision: D56020810 fbshipit-source-id: c14298130e702c1cbd67f1660f341178c15edbb4
1 parent 10c07a9 commit 96feeb1

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

torchrec/distributed/embedding_lookup.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -917,7 +917,12 @@ def __init__(
917917
MetaInferGroupedPooledEmbeddingsLookup(
918918
grouped_configs=grouped_configs_per_rank[rank],
919919
# syntax for torchscript
920-
device=torch.device(type=device_type, index=rank),
920+
# No rank for cpu
921+
device=(
922+
torch.device(type=device_type, index=rank)
923+
if device_type != "cpu"
924+
else torch.device(device_type)
925+
),
921926
fused_params=fused_params,
922927
)
923928
)

torchrec/distributed/quant_embedding_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def _unwrap_kjt(
163163
def _unwrap_kjt_for_cpu(
164164
features: KeyedJaggedTensor,
165165
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
166-
assert features.device().type == "cpu"
166+
assert features.device().type == "cpu" or features.device().type == "meta"
167167
return features.values(), features.offsets(), features.weights_or_none()
168168

169169

0 commit comments

Comments (0)