
Commit 0629dfb

Edenzzzz authored and pytorchmergebot committed
Fix FSDP offload pin_memory bug (pytorch#157147)
Fixes pytorch#157146
Pull Request resolved: pytorch#157147
Approved by: https://github.com/weifengpy
1 parent 67f8270 commit 0629dfb

File tree: 1 file changed (+2, -4 lines)


torch/distributed/fsdp/_fully_shard/_fsdp_param.py

Lines changed: 2 additions & 4 deletions
@@ -376,9 +376,7 @@ def _init_sharded_param(
         if self.offload_to_cpu and not padded_sharded_param.is_meta:
             padded_sharded_param = padded_sharded_param.cpu()
             if self.pin_memory:
-                padded_sharded_param = padded_sharded_param.pin_memory(
-                    device=self.device
-                )
+                padded_sharded_param = padded_sharded_param.pin_memory()
         self._sharded_param_data = padded_sharded_param.view(-1)
         length = sharded_param.size(shard_dim) if sharded_param.numel() > 0 else 0
         sharded_param = padded_sharded_param.narrow(
@@ -848,7 +846,7 @@ def reset_sharded_param(self):
             local_tensor = padded_local_tensor
             updated_local_tensor = True
         if self.pin_memory and not local_tensor.is_pinned():
-            local_tensor = local_tensor.cpu().pin_memory(device=self.device)
+            local_tensor = local_tensor.cpu().pin_memory()
             updated_local_tensor = True
         self._sharded_param_data = local_tensor.view(-1)
         assert isinstance(self.sharded_param, DTensor)  # mypy
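
The change drops the `device=` keyword from both `Tensor.pin_memory()` calls: pinning always places a tensor in page-locked host memory, and recent PyTorch versions deprecate the `device` argument, so passing `device=self.device` is unnecessary and appears to be what broke CPU offload here. A minimal sketch of the fixed call pattern follows; the tensor and variable names are illustrative, not taken from the commit:

```python
# Sketch of the fixed pin_memory usage, assuming a CUDA-enabled PyTorch
# build; `cpu_param` stands in for FSDP's offloaded sharded parameter data.
import torch

cpu_param = torch.randn(1024)        # sharded parameter data kept on CPU
if torch.cuda.is_available():
    pinned = cpu_param.pin_memory()  # no `device=` kwarg, as in the fix
    assert pinned.is_pinned()        # page-locked for fast host-to-device copies
    flat = pinned.view(-1)           # mirrors self._sharded_param_data
```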
