Skip to content

Commit 7711bf3

Browse files
kihyuksfacebook-github-bot
authored andcommitted
fix device error
Summary: When using `sample_farthest_points` with `lengths`, it throws an error because of the device mismatch between `lengths` and `torch.rand(lengths.size())` on GPU. Reviewed By: bottler Differential Revision: D82378997 fbshipit-source-id: 8e929256177d543d1dd1249e8488f70e03e4101f
1 parent d098beb commit 7711bf3

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

pytorch3d/ops/sample_farthest_points.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ def sample_farthest_points(
8989
if constant_length:
9090
start_idxs = torch.randint(high=P, size=(N,), device=device)
9191
else:
92-
start_idxs = (lengths * torch.rand(lengths.size())).to(torch.int64)
92+
start_idxs = (lengths * torch.rand(lengths.size(), device=device)).to(
93+
torch.int64
94+
)
9395
else:
9496
start_idxs = torch.zeros_like(lengths)
9597

0 commit comments

Comments
 (0)