File tree Expand file tree Collapse file tree 1 file changed +2
-4
lines changed
torch/distributed/fsdp/_fully_shard Expand file tree Collapse file tree 1 file changed +2
-4
lines changed Original file line number Diff line number Diff line change @@ -376,9 +376,7 @@ def _init_sharded_param(
376
376
if self .offload_to_cpu and not padded_sharded_param .is_meta :
377
377
padded_sharded_param = padded_sharded_param .cpu ()
378
378
if self .pin_memory :
379
- padded_sharded_param = padded_sharded_param .pin_memory (
380
- device = self .device
381
- )
379
+ padded_sharded_param = padded_sharded_param .pin_memory ()
382
380
self ._sharded_param_data = padded_sharded_param .view (- 1 )
383
381
length = sharded_param .size (shard_dim ) if sharded_param .numel () > 0 else 0
384
382
sharded_param = padded_sharded_param .narrow (
@@ -848,7 +846,7 @@ def reset_sharded_param(self):
848
846
local_tensor = padded_local_tensor
849
847
updated_local_tensor = True
850
848
if self .pin_memory and not local_tensor .is_pinned ():
851
- local_tensor = local_tensor .cpu ().pin_memory (device = self . device )
849
+ local_tensor = local_tensor .cpu ().pin_memory ()
852
850
updated_local_tensor = True
853
851
self ._sharded_param_data = local_tensor .view (- 1 )
854
852
assert isinstance (self .sharded_param , DTensor ) # mypy
You can’t perform that action at this time.
0 commit comments