Skip to content

Commit ef46d2f

Browse files
ezyang
authored and
Rob Kunkle
committed
Delete type_ field from TensorImpl, replaced with backend_/scalar_type_/is_variable_ (pytorch#9787)
Summary: The basic game plan is to stop accessing the type_ field directly, and instead using the stored backend_, scalar_type_ and is_variable_ to look up the appropriate Type from Context. Storage of backend_ and scalar_type_ are new. At some future point in time, I'd like to look at this code carefully to see if I can get everything in this codepath inlining. I didn't do it in this patch because there are circular include problems making things difficult. Some other details: - Added Device::backend() which does what it says on the tin - SparseTensorImpl is temporarily hard-coded to root in at::Context for the appropriate context. If/when we put this in shared code, we'll have to break this dep too, but for now it should be OK. - There's a stupid problem with globalContext() deadlocking if you didn't actually initialize it before loading libtorch.so (which is bringing along the variable hooks). I didn't fix it in this PR; it's tracked in pytorch#9784 Signed-off-by: Edward Z. Yang <[email protected]> Pull Request resolved: pytorch#9787 Reviewed By: cpuhrsch Differential Revision: D8980971 Pulled By: ezyang fbshipit-source-id: 2b4d867abfdc3999a836a220c638c109053145a8
1 parent 1968621 commit ef46d2f

14 files changed

+57
-22
lines changed

aten/src/ATen/Context.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,11 @@ Context::Context()
3737
Type::registerCPU(this);
3838
}
3939

40+
// NB: Ensure that globalContext is initialized before we load
41+
// variable hooks, otherwise we will deadlock. Regardless, the
42+
// deadlock is bad, and being tracked at https://github.com/pytorch/pytorch/issues/9784
43+
static Context globalContext_;
4044
Context & globalContext() {
41-
static Context globalContext_;
4245
return globalContext_;
4346
}
4447

aten/src/ATen/SparseTensorImpl.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ namespace at {
1818
// tensor and a [0] size values tensor for such an empty tensor. However,
1919
// we don't currently support zero-size dimensions, so we can't actually
2020
// do this; so we just allocate zero-size tensors for everything.
21-
SparseTensorImpl::SparseTensorImpl(Type * type)
22-
: TensorImpl(type, nullptr)
21+
SparseTensorImpl::SparseTensorImpl(at::Backend backend, at::ScalarType scalar_type)
22+
: TensorImpl(backend, scalar_type, nullptr, false)
2323
, size_{0}
2424
, sparseDims_(1)
2525
, denseDims_(0)
26-
, indices_(type->toDense().toScalarType(ScalarType::Long).tensor())
27-
, values_(type->toDense().tensor()) {
28-
AT_ASSERT(type->is_sparse());
26+
, indices_(globalContext().getTypeOpt(toDense(backend), ScalarType::Long)->tensor())
27+
, values_(globalContext().getTypeOpt(toDense(backend), scalar_type)->tensor()) {
28+
AT_ASSERT(backend == Backend::SparseCPU || backend == Backend::SparseCUDA);
2929
}
3030

3131
IntList SparseTensorImpl::sizes() const {

aten/src/ATen/SparseTensorImpl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ struct AT_API SparseTensorImpl : public TensorImpl {
4848

4949
public:
5050
// Public for now...
51-
explicit SparseTensorImpl(Type * type);
51+
explicit SparseTensorImpl(at::Backend, at::ScalarType);
5252

5353
int64_t nnz() const { return nnz_; }
5454
int64_t sparseDims() const { return sparseDims_; }

aten/src/ATen/TensorImpl.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,23 @@
22

33
#include <ATen/Tensor.h>
44
#include <ATen/optional.h>
5+
#include <ATen/Context.h>
6+
7+
#include <ATen/detail/VariableHooksInterface.h>
58

69
#include <TH/THTensor.hpp>
710

811
namespace at {
12+
13+
Type& TensorImpl::type() const {
14+
Type* base_type = &globalContext().getType(backend_, scalar_type_);
15+
if (is_variable_) {
16+
return detail::getVariableHooks().getVariableType(*base_type);
17+
} else {
18+
return *base_type;
19+
}
20+
}
21+
922
Tensor& TensorImpl::grad() {
1023
AT_ERROR("grad is not implemented for Tensor");
1124
}

aten/src/ATen/TensorImpl.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,18 @@ struct Tensor;
1818

1919
namespace at {
2020
struct AT_API TensorImpl : public Retainable {
21-
explicit TensorImpl(Type * type, THTensor * tensor)
22-
: type_(type), tensor(tensor) {}
21+
explicit TensorImpl(Backend backend, ScalarType scalar_type, THTensor * tensor, bool is_variable)
22+
: backend_(backend), scalar_type_(scalar_type), is_variable_(is_variable), tensor(tensor) {}
2323

2424
virtual ~TensorImpl();
2525

2626
virtual void release_resources() override;
2727

28-
Type & type() const {
29-
return *type_;
30-
}
28+
// The implementation of this method will have to be hoisted out and
29+
// hooked in, so that Caffe2 doesn't need to know about Context
30+
// TODO: This really really needs to be inlined.
31+
Type & type() const;
32+
3133
const char * toString() const;
3234
virtual IntList sizes() const;
3335
virtual IntList strides() const;
@@ -91,8 +93,12 @@ struct AT_API TensorImpl : public Retainable {
9193
virtual void set_data(Tensor new_data);
9294

9395
protected:
96+
Backend backend_;
97+
// INVARIANT: When storage is non-null, this scalar type must
98+
// agree with the scalar type in storage
99+
ScalarType scalar_type_;
100+
bool is_variable_ = false;
94101
bool is_wrapped_number_ = false;
95-
Type * type_;
96102
public:
97103
THTensor * tensor;
98104
};

aten/src/ATen/UndefinedTensor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ namespace at {
66

77
// should this use the globalContext? Can it get a context passed in somehow?
88
UndefinedTensor::UndefinedTensor()
9-
: TensorImpl(&(globalContext().getType(Backend::Undefined,ScalarType::Undefined)), nullptr) {
9+
: TensorImpl(Backend::Undefined, ScalarType::Undefined, nullptr, /* is variable */ false) {
1010
}
1111

1212
IntList UndefinedTensor::sizes() const {

aten/src/ATen/detail/VariableHooksInterface.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <ATen/Registry.h>
44
#include <ATen/Error.h>
55
#include <ATen/ScalarType.h>
6+
#include <ATen/Type.h>
67

78
namespace at {
89
class Context;
@@ -25,6 +26,10 @@ struct AT_API VariableHooksInterface {
2526
// squelch -Werror=non-virtual-dtor
2627
virtual ~VariableHooksInterface() {}
2728

29+
virtual Type& getVariableType(const at::Type& baseType) const {
30+
AT_ERROR("cannot getVariableType without libtorch");
31+
}
32+
2833
virtual void registerVariableTypeFor(Context*, Backend backend, ScalarType scalar_type) const {
2934
// no-op if Variable not available; it'll get handled (if at all) when
3035
// libtorch.so gets loaded

aten/src/ATen/native/sparse/SparseTensor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ SparseTensor new_sparse(const SparseType& dtype) {
6363
AT_ASSERT(!dtype.is_variable());
6464
AT_ASSERT(dtype.is_sparse());
6565
// TODO: Hmm... this const_cast business seems a bit dodgy
66-
return SparseTensor(new SparseTensorImpl(const_cast<SparseType*>(&dtype)), /* retain */ false);
66+
return SparseTensor(new SparseTensorImpl(dtype.backend(), dtype.scalarType()), /* retain */ false);
6767
}
6868

6969
/*** Helper methods ***/

aten/src/ATen/templates/TensorDerived.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace detail {
2121
}
2222

2323
${Tensor}::${Tensor}(${THTensor} * tensor)
24-
: TensorImpl(&globalContext().getType(Backend::${Backend},ScalarType::${ScalarName}), tensor)
24+
: TensorImpl(Backend::${Backend}, ScalarType::${ScalarName}, tensor, /* is variable */ false)
2525
{}
2626

2727
${TensorDenseOrSparse}

aten/src/TH/THTensor.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ struct THTensor
5656
return sizes_.size();
5757
}
5858

59+
at::ScalarType scalar_type() const {
60+
return storage_->scalar_type;
61+
}
62+
5963
ptrdiff_t storage_offset() const {
6064
return storage_offset_;
6165
}

torch/csrc/autograd/aten_variable_hooks.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ namespace torch { namespace autograd {
66
struct VariableHooks : public at::VariableHooksInterface {
77
VariableHooks(at::VariableHooksArgs) {}
88
void registerVariableTypeFor(at::Context*, at::Backend, at::ScalarType) const override;
9+
at::Type& getVariableType(const at::Type&) const override;
910
};
1011

1112
// Sigh, the registry doesn't support namespaces :(
@@ -20,4 +21,8 @@ void VariableHooks::registerVariableTypeFor(at::Context* context, at::Backend ba
2021
register_variable_type_for(baseType);
2122
}
2223

24+
at::Type& VariableHooks::getVariableType(const at::Type& baseType) const {
25+
return *VariableType::getType(baseType);
26+
}
27+
2328
}} // torch::autograd

torch/csrc/autograd/variable.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
namespace torch {
2323
namespace autograd {
2424
Variable::Impl::Impl(at::Tensor data, bool requires_grad, Edge gradient_edge)
25-
: TensorImpl(VariableType::getType(data), nullptr),
25+
: TensorImpl(data.type().backend(), data.type().scalarType(), nullptr, /* is variable */ true),
2626
data_(std::move(data)),
2727
grad_fn_(std::move(gradient_edge.function)),
2828
requires_grad_(false),
@@ -118,7 +118,9 @@ void Variable::Impl::backward(
118118

119119
void Variable::Impl::set_data(Tensor new_data) {
120120
if (new_data.type() != data_.type()) {
121-
type_ = VariableType::getType(new_data.type());
121+
scalar_type_ = new_data.type().scalarType();
122+
backend_ = new_data.type().backend();
123+
is_variable_ = true;
122124
// Clear grad_accumulator if it exists, since it stores the old type info.
123125
grad_accumulator_.reset();
124126
}

torch/csrc/autograd/variable.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,9 +327,6 @@ struct Variable::Impl : public at::TensorImpl {
327327
/// Reset all expensive fields to free up resources
328328
void release_resources() override;
329329

330-
// Make this field public so we can access it from `Variable`.
331-
using at::TensorImpl::type_;
332-
333330
std::string name;
334331
at::Tensor data_;
335332

torch/csrc/jit/interpreter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ struct PreprocessGraph {
337337
struct ContainerTensor : public at::TensorImpl {
338338
public:
339339
ContainerTensor()
340-
: TensorImpl(&(at::globalContext().getType(at::Backend::Undefined,at::ScalarType::Undefined)), nullptr) {}
340+
: TensorImpl(at::Backend::Undefined,at::ScalarType::Undefined, nullptr, /* is_variable */ false) {}
341341

342342
virtual ~ContainerTensor() = default;
343343
virtual at::IntList sizes() const override {

0 commit comments

Comments
 (0)