Skip to content

Remove .type().tensor() #626

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 13, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions torchvision/csrc/cpu/ROIAlign_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
auto height = input.size(2);
auto width = input.size(3);

at::Tensor output = input.type().tensor({num_rois, channels, pooled_height, pooled_width});

at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type());
auto output_size = num_rois * pooled_height * pooled_width * channels;

if (output.numel() == 0)
Expand Down
4 changes: 2 additions & 2 deletions torchvision/csrc/cpu/ROIPool_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(const at::Tensor &input,
int input_height = input.size(2);
int input_width = input.size(3);

at::Tensor output = input.type().tensor({num_rois, channels, pooled_height, pooled_width});
at::Tensor argmax = input.type().toScalarType(at::kInt).tensor({num_rois, channels, pooled_height, pooled_width}).zero_();
at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type());
at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type().toScalarType(at::kInt));

// define accessors for indexing
auto input_a = input.accessor<float, 4>();
Expand Down
7 changes: 3 additions & 4 deletions torchvision/csrc/cpu/nms_cpu.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "cpu/vision.h"


template <typename scalar_t>
at::Tensor nms_cpu_kernel(const at::Tensor& dets,
const at::Tensor& scores,
Expand All @@ -10,7 +9,7 @@ at::Tensor nms_cpu_kernel(const at::Tensor& dets,
AT_ASSERTM(dets.type() == scores.type(), "dets should have the same type as scores");

if (dets.numel() == 0)
return torch::CPU(at::kLong).tensor();
return at::empty({0}, at::device(at::kCPU).dtype(at::kLong));

auto x1_t = dets.select(1, 0).contiguous();
auto y1_t = dets.select(1, 1).contiguous();
Expand All @@ -22,7 +21,7 @@ at::Tensor nms_cpu_kernel(const at::Tensor& dets,
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));

auto ndets = dets.size(0);
at::Tensor suppressed_t = at::zeros(torch::CPU(at::kByte), {ndets});
at::Tensor suppressed_t = at::zeros({ndets}, at::device(at::kCPU).dtype(at::kByte));

auto suppressed = suppressed_t.data<uint8_t>();
auto order = order_t.data<int64_t>();
Expand Down Expand Up @@ -66,7 +65,7 @@ at::Tensor nms_cpu(const at::Tensor& dets,
const at::Tensor& scores,
const float threshold) {

auto result = dets.type().tensor();
auto result = at::empty({0}, dets.type());

AT_DISPATCH_FLOATING_TYPES(dets.type(), "nms", [&] {
result = nms_cpu_kernel<scalar_t>(dets, scores, threshold);
Expand Down
4 changes: 2 additions & 2 deletions torchvision/csrc/cuda/ROIAlign_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
auto height = input.size(2);
auto width = input.size(3);

at::Tensor output = input.type().tensor({num_rois, channels, pooled_height, pooled_width});
at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());

auto output_size = num_rois * pooled_height * pooled_width * channels;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
Expand Down Expand Up @@ -313,7 +313,7 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");

auto num_rois = rois.size(0);
at::Tensor grad_input = grad.type().tensor({batch_size, channels, height, width}).zero_();
at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

Expand Down
4 changes: 2 additions & 2 deletions torchvision/csrc/cuda/ROIPool_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
auto height = input.size(2);
auto width = input.size(3);

at::Tensor output = input.type().tensor({num_rois, channels, pooled_height, pooled_width});
at::Tensor argmax = input.type().toScalarType(at::kInt).tensor({num_rois, channels, pooled_height, pooled_width}).zero_();
at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type());
at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type().toScalarType(at::kInt));

auto output_size = num_rois * pooled_height * pooled_width * channels;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
Expand Down