
Commit b51ac15

remove .type().tensor() calls in favor of the new approach to tensor initialization
1 parent ede93b2 commit b51ac15
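
For context, the old idiom asked a tensor's Type object to allocate new storage (input.type().tensor(sizes)), which this commit moves away from in favor of ATen factory functions that take the sizes first and a type or options argument second. A minimal sketch of the before/after, assuming the ATen of this era and using illustrative names and sizes that are not taken from the diff:

// Sketch only: contrasts the idiom this commit removes with the one it adds.
#include <ATen/ATen.h>

at::Tensor make_output(const at::Tensor& input) {
  // Old style (removed here): the Type object builds the tensor.
  //   at::Tensor out = input.type().tensor({4, 3, 7, 7});
  // New style: a factory function, sizes first, type/options second.
  return at::zeros({4, 3, 7, 7}, input.type());
}

at::zeros also zero-fills the allocation, whereas the old construction returned uninitialized memory, which is why several call sites below can drop an explicit zeroing step.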

5 files changed, +11 −12 lines


torchvision/csrc/cpu/ROIAlign_cpu.cpp

Lines changed: 2 additions & 2 deletions
@@ -231,8 +231,8 @@ at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
   auto height = input.size(2);
   auto width = input.size(3);

-  at::Tensor output = input.type().tensor({num_rois, channels, pooled_height, pooled_width});
-
+  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type());
+
   auto output_size = num_rois * pooled_height * pooled_width * channels;

   if (output.numel() == 0)

torchvision/csrc/cpu/ROIPool_cpu.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(const at::Tensor &input,
1616
int input_height = input.size(2);
1717
int input_width = input.size(3);
1818

19-
at::Tensor output = input.type().tensor({num_rois, channels, pooled_height, pooled_width});
20-
at::Tensor argmax = input.type().toScalarType(at::kInt).tensor({num_rois, channels, pooled_height, pooled_width}).zero_();
19+
at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type());
20+
at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type().toScalarType(at::kInt));
2121

2222
// define accessors for indexing
2323
auto input_a = input.accessor<float, 4>();
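
Note that the argmax buffer above also loses its trailing .zero_(): at::zeros returns already zero-filled memory, and toScalarType(at::kInt) keeps the backend of input.type() while switching the element type to 32-bit int. A small sketch of that pattern, with illustrative sizes:

// Sketch only: zero-filled int buffer on the same backend as a float input.
#include <ATen/ATen.h>

at::Tensor make_argmax(const at::Tensor& input) {
  // Same backend as input, int elements, zero-initialized in one call.
  return at::zeros({2, 3, 7, 7}, input.type().toScalarType(at::kInt));
}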

torchvision/csrc/cpu/nms_cpu.cpp

Lines changed: 3 additions & 4 deletions
@@ -1,6 +1,5 @@
 #include "cpu/vision.h"

-
 template <typename scalar_t>
 at::Tensor nms_cpu_kernel(const at::Tensor& dets,
                           const at::Tensor& scores,
@@ -10,7 +9,7 @@ at::Tensor nms_cpu_kernel(const at::Tensor& dets,
   AT_ASSERTM(dets.type() == scores.type(), "dets should have the same type as scores");

   if (dets.numel() == 0)
-    return torch::CPU(at::kLong).tensor();
+    return at::empty({0}, at::device(at::kCPU).dtype(at::kLong));

   auto x1_t = dets.select(1, 0).contiguous();
   auto y1_t = dets.select(1, 1).contiguous();
@@ -22,7 +21,7 @@ at::Tensor nms_cpu_kernel(const at::Tensor& dets,
   auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));

   auto ndets = dets.size(0);
-  at::Tensor suppressed_t = at::zeros(torch::CPU(at::kByte), {ndets});
+  at::Tensor suppressed_t = at::zeros({ndets}, at::device(at::kCPU).dtype(at::kByte));

   auto suppressed = suppressed_t.data<uint8_t>();
   auto order = order_t.data<int64_t>();
@@ -66,7 +65,7 @@ at::Tensor nms_cpu(const at::Tensor& dets,
                    const at::Tensor& scores,
                    const float threshold) {

-  auto result = dets.type().tensor();
+  auto result = at::empty({0}, dets.type());

   AT_DISPATCH_FLOATING_TYPES(dets.type(), "nms", [&] {
     result = nms_cpu_kernel<scalar_t>(dets, scores, threshold);
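
Two details in nms_cpu.cpp go beyond a mechanical swap: the empty return values are now built with at::empty({0}, ...) plus a device/dtype options chain rather than a bare Type, and the suppressed_t call now passes the sizes as the first argument, matching the factory signature used everywhere else in this commit. A short sketch of both patterns, with an illustrative ndets value:

// Sketch only: the two tensor-creation patterns used in nms_cpu.cpp above.
#include <ATen/ATen.h>

int main() {
  // Empty 1-D long tensor on the CPU, e.g. the "no detections" result.
  at::Tensor keep = at::empty({0}, at::device(at::kCPU).dtype(at::kLong));

  // Zero-filled byte mask: sizes first, options second.
  int64_t ndets = 8;
  at::Tensor suppressed = at::zeros({ndets}, at::device(at::kCPU).dtype(at::kByte));

  return (keep.numel() == 0 && suppressed.numel() == ndets) ? 0 : 1;
}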

torchvision/csrc/cuda/ROIAlign_cuda.cu

Lines changed: 2 additions & 2 deletions
@@ -267,7 +267,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
   auto height = input.size(2);
   auto width = input.size(3);

-  at::Tensor output = input.type().tensor({num_rois, channels, pooled_height, pooled_width});
+  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());

   auto output_size = num_rois * pooled_height * pooled_width * channels;
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -313,7 +313,7 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
   AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");

   auto num_rois = rois.size(0);
-  at::Tensor grad_input = grad.type().tensor({batch_size, channels, height, width}).zero_();
+  at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());

   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
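
The CUDA file switches to input.options() and grad.options() rather than passing a Type: TensorOptions carries dtype, layout, and device together, so the fresh output or grad_input tensor is allocated on the same GPU as the tensor it is derived from. A minimal sketch of the idiom, with illustrative names and sizes:

// Sketch only: allocate a zero-filled tensor with the same dtype and device
// (including the CUDA device index) as an existing tensor.
#include <ATen/ATen.h>

at::Tensor make_grad_input(const at::Tensor& grad) {
  // grad.options() bundles dtype, layout, and device, so if grad lives on
  // a particular GPU the new tensor is allocated there as well.
  return at::zeros({2, 256, 14, 14}, grad.options());
}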

torchvision/csrc/cuda/ROIPool_cuda.cu

Lines changed: 2 additions & 2 deletions
@@ -116,8 +116,8 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
   auto height = input.size(2);
   auto width = input.size(3);

-  at::Tensor output = input.type().tensor({num_rois, channels, pooled_height, pooled_width});
-  at::Tensor argmax = input.type().toScalarType(at::kInt).tensor({num_rois, channels, pooled_height, pooled_width}).zero_();
+  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type());
+  at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type().toScalarType(at::kInt));

   auto output_size = num_rois * pooled_height * pooled_width * channels;
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
