
Commit cee743f

gchanan authored and facebook-github-bot committed
Move backward/set_data to Type-based dispatch.
Summary:
Pull Request resolved: pytorch#11440

Differential Revision: D9736565
Pulled By: gchanan
fbshipit-source-id: 1e66f54f1c87084f37c0b014030f0d6d2f8dfaee
1 parent 87a9a8f commit cee743f
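
In short: backward and set_data used to be virtual methods on TensorImpl; this commit moves them onto the Type hierarchy (TypeDefault provides the error fallback, VariableType forwards into autograd), and the Tensor methods become thin inline wrappers that dispatch through type(). Below is a minimal standalone sketch of that pattern, not the real ATen classes: only the names Type, TypeDefault, VariableType, backward, and set_data come from the diff; the toy Tensor, the messages, and main are invented here for illustration.

#include <iostream>
#include <stdexcept>

struct Type;

// Toy stand-in for at::Tensor: it just carries a pointer to its Type.
struct Tensor {
  const Type* type_;
  const Type& type() const { return *type_; }
  void backward(bool keep_graph = false, bool create_graph = false);
  void set_data(Tensor new_data);
};

// Analogue of at::Type: one virtual dispatch table per backend/variable-ness.
struct Type {
  virtual ~Type() = default;
  virtual void backward(Tensor& self, bool keep_graph, bool create_graph) const = 0;
  virtual void set_data(Tensor& self, Tensor new_data) const = 0;
};

// Analogue of TypeDefault: plain (non-variable) tensors reject autograd calls.
struct TypeDefault : Type {
  void backward(Tensor&, bool, bool) const override {
    throw std::runtime_error("backward is not implemented for Tensor");
  }
  void set_data(Tensor&, Tensor) const override {
    throw std::runtime_error("set_data is not implemented for Tensor");
  }
};

// Analogue of VariableType: the override that actually reaches autograd.
struct VariableType : TypeDefault {
  void backward(Tensor&, bool keep_graph, bool) const override {
    std::cout << "autograd backward, keep_graph=" << keep_graph << "\n";
  }
  void set_data(Tensor&, Tensor) const override {
    std::cout << "autograd set_data\n";
  }
};

// Tensor methods now dispatch through type(), mirroring TensorMethods.h below.
void Tensor::backward(bool keep_graph, bool create_graph) {
  type().backward(*this, keep_graph, create_graph);
}
void Tensor::set_data(Tensor new_data) { type().set_data(*this, new_data); }

int main() {
  VariableType vt;
  Tensor t{&vt};
  t.backward(/*keep_graph=*/true);  // routes to VariableType::backward
}

The per-file diffs below show the same change in the actual codebase.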

File tree: 15 files changed, 66 insertions(+), 47 deletions(-)

aten/src/ATen/SparseTensorImpl.cpp

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ namespace {
 // This means that we allocate a [1,0] size indices tensor and a [0] size
 // values tensor for such an empty tensor.
 SparseTensorImpl::SparseTensorImpl(at::TensorTypeId type_id, at::ScalarType scalar_type)
-  : TensorImpl(type_id, scalar_type, false)
+  : TensorImpl(type_id, scalar_type, nullptr, false)
   , size_{0}
   , sparseDims_(1)
   , denseDims_(0)

aten/src/ATen/TensorImpl.cpp

Lines changed: 2 additions & 22 deletions
@@ -1,6 +1,5 @@
 #include <ATen/TensorImpl.h>

-#include <ATen/Type.h>
 #include <ATen/core/optional.h>
 #include <ATen/core/Backend.h>
 #include <ATen/core/WrapDimMinimal.h>
@@ -18,31 +17,12 @@ const Tensor& TensorImpl::grad() const {
   AT_ERROR("grad is not implemented for Tensor");
 }

-void TensorImpl::backward(
-    at::optional<Tensor> gradient,
-    bool keep_graph,
-    bool create_graph) {
-  AT_ERROR("backward is not implemented for Tensor");
-}
-
-void TensorImpl::set_data(Tensor new_data) {
-  AT_ERROR("set_type is not implemented for Tensor");
-}
-
-void Tensor::backward(
-    at::optional<Tensor> gradient,
-    bool keep_graph,
-    bool create_graph) {
-  tensor_impl_->backward(std::move(gradient), keep_graph, create_graph);
-}
-
-TensorImpl::TensorImpl(TensorTypeId type_id, ScalarType scalar_type, bool is_variable)
+TensorImpl::TensorImpl(TensorTypeId type_id, ScalarType scalar_type, Allocator *allocator, bool is_variable)
   : TensorImpl({}, type_id, scalar_type, is_variable) {
   // UndefinedTensors and SparseTensors don't have storages.
   if (type_id != UndefinedTensorId() && scalar_type != ScalarType::Undefined
       && type_id != SparseCPUTensorId() && type_id != SparseCUDATensorId()) {
-    auto type = &globalLegacyTypeDispatch().getNonVariableType(tensorTypeIdToBackend(type_id), scalar_type);
-    storage_ = type->storage(true);
+    storage_ = Storage(scalar_type, 0, allocator, true);
   }
 }
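
A side effect visible in this hunk: TensorImpl no longer reaches into globalLegacyTypeDispatch() just to create an empty Storage; the caller injects the Allocator and the Storage is constructed directly. A hedged sketch of that injection pattern follows, with every type here a toy stand-in rather than the real ATen class:

#include <cstddef>

struct Allocator {
  virtual ~Allocator() = default;
  virtual void* allocate(std::size_t n) = 0;
};

struct Storage {
  Allocator* allocator_;  // retained so the buffer can be allocated/grown later
  std::size_t size_;
  Storage(std::size_t size, Allocator* allocator)
      : allocator_(allocator), size_(size) {}
};

struct TensorImpl {
  Storage storage_;
  // Before: storage came from a globally looked-up Type object.
  // After: build an empty Storage directly from the injected allocator
  // (the real constructor skips this entirely for undefined/sparse tensors).
  explicit TensorImpl(Allocator* allocator) : storage_(0, allocator) {}
};

int main() {
  Allocator* alloc = nullptr;  // a real caller would pass a concrete allocator
  TensorImpl t(alloc);
  (void)t;
}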

aten/src/ATen/TensorImpl.h

Lines changed: 1 addition & 8 deletions
@@ -22,7 +22,7 @@ struct Tensor;
 namespace at {
 struct AT_API TensorImpl : public c10::intrusive_ptr_target {
   TensorImpl() = delete;
-  TensorImpl(TensorTypeId type_id, ScalarType scalar_type, bool is_variable);
+  TensorImpl(TensorTypeId type_id, ScalarType scalar_type, Allocator *allocator, bool is_variable);
   TensorImpl(Storage&& storage, TensorTypeId type_id, bool is_variable);

   virtual void release_resources() override;
@@ -90,13 +90,6 @@ struct AT_API TensorImpl : public c10::intrusive_ptr_target {
   virtual Tensor& grad();
   virtual const Tensor& grad() const;

-  virtual void backward(
-      at::optional<Tensor> gradient,
-      bool keep_graph,
-      bool create_graph);
-
-  virtual void set_data(Tensor new_data);
-
   // TODO: make these protected
   // Note: storage->size() may be greater than the recorded size
   // of a tensor

aten/src/ATen/UndefinedTensor.cpp

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ namespace at {

 // should this use the globalContext? Can it get a context passed in somehow?
 UndefinedTensor::UndefinedTensor()
-: TensorImpl(UndefinedTensorId(), ScalarType::Undefined, /* is variable */ false) {
+: TensorImpl(UndefinedTensorId(), ScalarType::Undefined, nullptr, /* is variable */ false) {
 }

 IntList UndefinedTensor::sizes() const {

aten/src/ATen/function_wrapper.py

Lines changed: 6 additions & 6 deletions
@@ -331,17 +331,17 @@ def __init__(self, reason):

 ALLOC_NOARGS_WRAP = {
     'THTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensor>'
-        '(${Backend}TensorId(), ScalarType::${ScalarName}, false).release()',
+        '(${Backend}TensorId(), ScalarType::${ScalarName}, allocator(), false).release()',
     'THBoolTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensor>'
-        '(${Backend}TensorId(), ScalarType::Byte, false).release()',
+        '(${Backend}TensorId(), ScalarType::Byte, allocator(), false).release()',
     'THIndexTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensor>'
-        '(${Backend}TensorId(), ScalarType::Long, false).release()',
+        '(${Backend}TensorId(), ScalarType::Long, allocator(), false).release()',
     'THIntegerTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensor>'
-        '(${Backend}TensorId(), ScalarType::Int, false).release()',
+        '(${Backend}TensorId(), ScalarType::Int, allocator(), false).release()',
     'THDenseTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensor>'
-        '(${Backend}TensorId(), ScalarType::${ScalarName}, false).release()',
+        '(${Backend}TensorId(), ScalarType::${ScalarName}, allocator(), false).release()',
     'THDenseIndexTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensor>'
-        '(${Backend}TensorId(), ScalarType::Long, false).release()'
+        '(${Backend}TensorId(), ScalarType::Long, allocator(), false).release()'
 }

 ALLOC_WRAP = {

aten/src/ATen/templates/Tensor.h

Lines changed: 1 addition & 3 deletions
@@ -244,9 +244,7 @@ struct AT_API Tensor {
     return tensor_impl_->grad();
   }

-  void set_data(Tensor new_data) {
-    tensor_impl_->set_data(new_data);
-  }
+  void set_data(Tensor new_data);

   /// Computes the gradient of current tensor w.r.t. graph leaves.
   void backward(

aten/src/ATen/templates/TensorMethods.h

Lines changed: 11 additions & 0 deletions
@@ -83,6 +83,17 @@ inline Tensor Tensor::to(Device device, bool non_blocking) const {
   return detail::to(*this, options().device(device), non_blocking);
 }

+inline void Tensor::backward(
+    at::optional<Tensor> gradient,
+    bool keep_graph,
+    bool create_graph) {
+  type().backward(*this, std::move(gradient), keep_graph, create_graph);
+}
+
+inline void Tensor::set_data(Tensor new_data) {
+  type().set_data(*this, new_data);
+}
+
 // all static inline to allow for inlining of the non-dynamic part of dispatch
 ${tensor_method_definitions}

aten/src/ATen/templates/Type.h

Lines changed: 3 additions & 0 deletions
@@ -104,6 +104,9 @@ struct AT_API Type {
   virtual Tensor & s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const = 0;
   virtual Tensor & _s_copy_from(const Tensor & self, Tensor & dst, bool non_blocking) const = 0;

+  virtual void backward(Tensor & self, at::optional<Tensor> gradient, bool keep_graph, bool create_graph) const = 0;
+  virtual void set_data(Tensor & self, Tensor new_data) const = 0;
+
   virtual Tensor tensorFromBlob(void * data, IntList sizes, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
   virtual Tensor tensorFromBlob(void * data, IntList sizes, IntList strides, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
   virtual Tensor tensorWithAllocator(IntList sizes, Allocator* allocator) const = 0;

aten/src/ATen/templates/TypeDefault.cpp

Lines changed: 8 additions & 0 deletions
@@ -40,6 +40,14 @@ Tensor TypeDefault::copy(const Tensor & src, bool non_blocking) const {
   }
 }

+void TypeDefault::backward(Tensor & self, at::optional<Tensor> gradient, bool keep_graph, bool create_graph) const {
+  AT_ERROR("backward is not implemented for Tensor");
+}
+
+void TypeDefault::set_data(Tensor & self, Tensor new_data) const {
+  AT_ERROR("set_data is not implemented for Tensor");
+}
+
 Type & TypeDefault::toBackend(Backend b) const {
   return at::globalContext().getNonVariableType(b,scalarType());
 }

aten/src/ATen/templates/TypeDefault.h

Lines changed: 3 additions & 0 deletions
@@ -28,6 +28,9 @@ struct AT_API TypeDefault : public Type {
   Tensor copy(const Tensor & src, bool non_blocking=false) const override;
   Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking=false) const override;

+  void backward(Tensor & self, at::optional<Tensor> gradient, bool keep_graph, bool create_graph) const override;
+  void set_data(Tensor & self, Tensor new_data) const override;
+
   Tensor tensorFromBlob(void * data, IntList sizes, const std::function<void(void*)> & deleter=noop_deleter) const override;
   Tensor tensorFromBlob(void * data, IntList sizes, IntList strides, const std::function<void(void*)> & deleter=noop_deleter) const override;
   Tensor tensorWithAllocator(IntList sizes, Allocator* allocator) const override;

tools/autograd/templates/VariableType.cpp

Lines changed: 7 additions & 0 deletions
@@ -378,6 +378,13 @@ static bool isFloatingPoint(ScalarType s) {
   return s == kFloat || s == kDouble || s == kHalf;
 }

+void VariableType::backward(Tensor & self, at::optional<Tensor> gradient, bool keep_graph, bool create_graph) const {
+  as_variable_ref(self).backward(gradient, keep_graph, create_graph);
+}
+
+void VariableType::set_data(Tensor & self, Tensor new_data) const {
+  as_variable_ref(self).set_data(new_data);
+}
 Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const {
   jit::Node* node = nullptr;
   if(torch::jit::tracer::isTracing()) {
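
The two overrides above simply downcast and forward into the autograd Variable. As a rough illustration of the as_variable_ref pattern (the real helper verifies that the tensor is a variable before casting; this stand-in uses dynamic_cast, and all the types here are invented):

#include <stdexcept>

struct Tensor { virtual ~Tensor() = default; };

struct Variable : Tensor {
  void backward(bool keep_graph, bool create_graph) { /* enter autograd */ }
  void set_data(const Tensor& new_data) { /* rebind underlying data */ }
};

// Checked downcast: only Variables may be handed to the autograd layer.
Variable& as_variable_ref(Tensor& t) {
  auto* v = dynamic_cast<Variable*>(&t);
  if (v == nullptr) throw std::runtime_error("expected a Variable");
  return *v;
}

int main() {
  Variable v;
  Tensor& t = v;
  as_variable_ref(t).backward(/*keep_graph=*/false, /*create_graph=*/false);
}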

tools/autograd/templates/VariableType.h

Lines changed: 4 additions & 0 deletions
@@ -58,6 +58,10 @@ struct TORCH_API VariableType final : public at::TypeDefault {

   Tensor & s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const override;
   Tensor & _s_copy_from(const Tensor & self, Tensor & dst, bool non_blocking) const override;
+
+  void backward(Tensor & self, at::optional<Tensor> gradient, bool keep_graph, bool create_graph) const override;
+  void set_data(Tensor & self, Tensor new_data) const override;
+
   ${type_derived_method_declarations}

 private:

torch/csrc/autograd/variable.cpp

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@
 namespace torch {
 namespace autograd {
 Variable::Impl::Impl(at::Tensor data, bool requires_grad, Edge gradient_edge)
-    : TensorImpl(data.type().type_id(), data.type().scalarType(), /* is variable */ true),
+    : TensorImpl(data.type().type_id(), data.type().scalarType(), data.type().allocator(), /* is variable */ true),
       data_(std::move(data)),
       grad_fn_(std::move(gradient_edge.function)),
       requires_grad_(false),

torch/csrc/autograd/variable.h

Lines changed: 16 additions & 4 deletions
@@ -187,6 +187,12 @@ struct TORCH_API Variable : public at::Tensor {
   /// this. If this `Variable` is a view, throws an `std::runtime_error()`.
   void detach_();

+  /// Computes the gradient of current tensor w.r.t. graph leaves.
+  void backward(at::optional<Tensor> gradient, bool keep_graph, bool create_graph) const;
+
+  /// Sets the type of the Variable.
+  void set_data(Tensor new_data) const;
+
   /// Set the gradient edge -- i.e. `grad_fn` and `input_nr` -- of the
   /// `Variable`.
   /// NOTE: This will always set the `grad_fn`, even if this is a leaf variable,
@@ -324,14 +330,12 @@ struct TORCH_API Variable::Impl : public at::TensorImpl {
   Variable detach() const;
   void detach_();

-  /// Sets the type of the Variable.
-  void set_data(Tensor new_data) override;
+  void set_data(Tensor new_data);

-  /// Computes the gradient of current tensor w.r.t. graph leaves.
   void backward(
       at::optional<at::Tensor> gradient,
       bool keep_graph,
-      bool create_graph) override;
+      bool create_graph);

   /// Reset all expensive fields to free up resources
   void release_resources() override;
@@ -500,6 +504,14 @@ inline void Variable::detach_() {
   get()->detach_();
 }

+inline void Variable::backward(at::optional<Tensor> gradient, bool keep_graph, bool create_graph) const {
+  get()->backward(gradient, keep_graph, create_graph);
+}
+
+inline void Variable::set_data(Tensor new_data) const {
+  get()->set_data(new_data);
+}
+
 inline void Variable::set_gradient_edge(Edge edge) noexcept {
   get()->grad_fn_ = std::move(edge.function);
   get()->output_nr_ = edge.input_nr;
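
The inline definitions in the last hunk follow a pImpl-style forwarding idiom: the public Variable methods fetch the Impl via get() and delegate. A minimal sketch with invented names:

#include <iostream>

struct Impl {
  void backward(bool keep_graph, bool create_graph) {
    std::cout << "Impl::backward, keep_graph=" << keep_graph << "\n";
  }
  void set_data(int new_data) { std::cout << "Impl::set_data\n"; }
};

struct Handle {
  Impl* impl_;
  Impl* get() const { return impl_; }
  // Inline forwarders, mirroring Variable::backward / Variable::set_data.
  void backward(bool keep_graph, bool create_graph) const {
    get()->backward(keep_graph, create_graph);
  }
  void set_data(int new_data) const { get()->set_data(new_data); }
};

int main() {
  Impl impl;
  Handle h{&impl};
  h.backward(/*keep_graph=*/true, /*create_graph=*/false);
}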

torch/csrc/jit/interpreter.cpp

Lines changed: 1 addition & 1 deletion
@@ -337,7 +337,7 @@ struct PreprocessGraph {
 struct ContainerTensor : public at::TensorImpl {
  public:
   ContainerTensor()
-  : TensorImpl(at::UndefinedTensorId(), at::ScalarType::Undefined, /* is_variable */ false) {}
+  : TensorImpl(at::UndefinedTensorId(), at::ScalarType::Undefined, nullptr, /* is_variable */ false) {}

   virtual ~ContainerTensor() = default;
   virtual at::IntList sizes() const override {
