diff --git a/torchvision/models/detection/anchor_utils.py b/torchvision/models/detection/anchor_utils.py index 06ecc551442..0057da45e24 100644 --- a/torchvision/models/detection/anchor_utils.py +++ b/torchvision/models/detection/anchor_utils.py @@ -97,10 +97,10 @@ def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) # For output anchor, compute [x_center, y_center, x_center, y_center] shifts_x = torch.arange( - 0, grid_width, dtype=torch.float32, device=device + 0, grid_width, dtype=torch.int32, device=device ) * stride_width shifts_y = torch.arange( - 0, grid_height, dtype=torch.float32, device=device + 0, grid_height, dtype=torch.int32, device=device ) * stride_height shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) shift_x = shift_x.reshape(-1)