Skip to content

Merge from upstream #120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 13, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions aten/src/ATen/Backend.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
#pragma once

#include <ATen/core/TensorTypeId.h>
#include <ATen/core/TensorTypeIdRegistration.h>
#include <ATen/core/Error.h>

#include <stdexcept>

namespace at {
Expand Down Expand Up @@ -40,6 +45,39 @@ static inline Backend toDense(Backend b) {
}
}

// Maps a dynamic TensorTypeId onto the fixed set of Backend values that the
// legacy ATen dispatcher understands. Fails loudly for any id outside that set.
static inline Backend tensorTypeIdToBackend(TensorTypeId t) {
  if (t == UndefinedTensorId()) {
    return Backend::Undefined;
  }
  if (t == CPUTensorId()) {
    return Backend::CPU;
  }
  if (t == CUDATensorId()) {
    return Backend::CUDA;
  }
  if (t == SparseCPUTensorId()) {
    return Backend::SparseCPU;
  }
  if (t == SparseCUDATensorId()) {
    return Backend::SparseCUDA;
  }
  AT_ERROR("Unrecognized tensor type ID: ", t);
}

// Inverse of tensorTypeIdToBackend: maps a legacy Backend enum value to the
// TensorTypeId registered for it. Errors out for Backend values with no
// registered type id.
static inline TensorTypeId backendToTensorTypeId(Backend b) {
  switch (b) {
    case Backend::CPU:
      return CPUTensorId();
    case Backend::CUDA:
      return CUDATensorId();
    case Backend::SparseCPU:
      return SparseCPUTensorId();
    case Backend::SparseCUDA:
      return SparseCUDATensorId();
    case Backend::Undefined:
      return UndefinedTensorId();
    default:
      // Use AT_ERROR for consistency with tensorTypeIdToBackend and the rest
      // of ATen's error reporting (Error.h is already included by this header),
      // instead of a bare std::runtime_error.
      AT_ERROR("Unknown backend");
  }
}

static inline const char* toString(Backend b) {
switch (b) {
case Backend::CPU:
Expand Down
22 changes: 16 additions & 6 deletions aten/src/ATen/SparseTensorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@

namespace at {

namespace {
// Given the type id of a sparse tensor, returns the Backend of the dense
// tensors (indices/values) backing it. Only the two sparse ids are accepted.
Backend sparseTensorIdToDenseBackend(TensorTypeId type_id) {
  if (type_id == SparseCPUTensorId()) {
    return Backend::CPU;
  }
  if (type_id == SparseCUDATensorId()) {
    return Backend::CUDA;
  }
  AT_ERROR("Cannot construct SparseTensor with non-sparse tensor type ID ", type_id);
}
}


// An empty dense tensor defaults to a 1-dimensional tensor of size [0]
// (recall, it is not a 0-dimensional tensor, because such a tensor would
Expand All @@ -18,15 +30,13 @@ namespace at {
// tensor and a [0] size values tensor for such an empty tensor. However,
// we don't currently support zero-size dimensions, so we can't actually
// do this; so we just allocate zero-size tensors for everything.
SparseTensorImpl::SparseTensorImpl(at::Backend backend, at::ScalarType scalar_type)
: TensorImpl(backend, scalar_type, nullptr, false)
SparseTensorImpl::SparseTensorImpl(at::TensorTypeId type_id, at::ScalarType scalar_type)
: TensorImpl(type_id, scalar_type, nullptr, false)
, size_{0}
, sparseDims_(1)
, denseDims_(0)
, indices_(globalContext().getTypeOpt(toDense(backend), ScalarType::Long)->tensor())
, values_(globalContext().getTypeOpt(toDense(backend), scalar_type)->tensor()) {
AT_ASSERT(backend == Backend::SparseCPU || backend == Backend::SparseCUDA);
}
, indices_(globalContext().getTypeOpt(sparseTensorIdToDenseBackend(type_id), ScalarType::Long)->tensor())
, values_(globalContext().getTypeOpt(sparseTensorIdToDenseBackend(type_id), scalar_type)->tensor()) {}

IntList SparseTensorImpl::sizes() const {
return size_;
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/SparseTensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct AT_API SparseTensorImpl : public TensorImpl {

public:
// Public for now...
explicit SparseTensorImpl(at::Backend, at::ScalarType);
explicit SparseTensorImpl(at::TensorTypeId, at::ScalarType);

int64_t nnz() const { return nnz_; }
int64_t sparseDims() const { return sparseDims_; }
Expand Down
13 changes: 8 additions & 5 deletions aten/src/ATen/TensorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <ATen/Tensor.h>
#include <ATen/core/optional.h>
#include <ATen/Context.h>
#include <ATen/Backend.h>

#include <ATen/detail/VariableHooksInterface.h>

Expand All @@ -12,7 +13,10 @@
namespace at {

Type& TensorImpl::type() const {
Type* base_type = &globalContext().getType(backend_, scalar_type_);
// Select backend from the hard-coded ones that the legacy ATen dispatcher
// knows about
Backend backend = tensorTypeIdToBackend(type_id_);
Type* base_type = &globalContext().getType(backend, scalar_type_);
if (is_variable_) {
return detail::getVariableHooks().getVariableType(*base_type);
} else {
Expand Down Expand Up @@ -55,10 +59,9 @@ void Tensor::backward(
pImpl->backward(std::move(gradient), keep_graph, create_graph);
}

TensorImpl::TensorImpl(Backend backend, ScalarType scalar_type) {
backend_ = backend;
scalar_type_ = scalar_type;
auto type = &globalContext().getType(backend, scalar_type);
TensorImpl::TensorImpl(TensorTypeId type_id, ScalarType scalar_type)
: type_id_(type_id), scalar_type_(scalar_type) {
auto type = &globalContext().getType(tensorTypeIdToBackend(type_id), scalar_type);
Storage* storage = type->storage(true).release();
StorageImpl* storage_impl = storage->pImpl();
tensor = new THTensor(storage_impl);
Expand Down
10 changes: 6 additions & 4 deletions aten/src/ATen/TensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include "ATen/Retainable.h"
#include "ATen/ScalarType.h"
#include "ATen/core/optional.h"
#include "ATen/core/TensorTypeId.h"
#include "ATen/core/TensorTypeIdRegistration.h"

struct THTensor;

Expand All @@ -18,9 +20,9 @@ struct Tensor;

namespace at {
struct AT_API TensorImpl : public Retainable {
explicit TensorImpl(Backend backend, ScalarType scalar_type, THTensor * tensor, bool is_variable)
: backend_(backend), scalar_type_(scalar_type), is_variable_(is_variable), tensor(tensor) {}
TensorImpl(Backend backend, ScalarType scalar_type);
explicit TensorImpl(TensorTypeId type_id, ScalarType scalar_type, THTensor * tensor, bool is_variable)
: type_id_(type_id), scalar_type_(scalar_type), is_variable_(is_variable), tensor(tensor) {}
TensorImpl(TensorTypeId type_id, ScalarType scalar_type);

virtual ~TensorImpl();

Expand Down Expand Up @@ -94,7 +96,7 @@ struct AT_API TensorImpl : public Retainable {
virtual void set_data(Tensor new_data);

protected:
Backend backend_;
TensorTypeId type_id_;
// INVARIANT: When storage is non-null, this scalar type must
// agree with the scalar type in storage
ScalarType scalar_type_;
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/UndefinedTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace at {

// should this use the globalContext? Can it get a context passed in somehow?
UndefinedTensor::UndefinedTensor()
: TensorImpl(Backend::Undefined, ScalarType::Undefined, nullptr, /* is variable */ false) {
: TensorImpl(UndefinedTensorId(), ScalarType::Undefined, nullptr, /* is variable */ false) {
}

IntList UndefinedTensor::sizes() const {
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/UndefinedType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
namespace at {

UndefinedType::UndefinedType(Context* context)
: Type(context, /*is_variable=*/false, /*is_undefined=*/true) {}
: Type(context, UndefinedTensorId(), /*is_variable=*/false, /*is_undefined=*/true) {}
ScalarType UndefinedType::scalarType() const {
return ScalarType::Undefined;
}
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/core/DeviceType.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#pragma once

// This is directly synchronized with caffe2/proto/caffe2.proto, but
// doesn't require me to figure out how to get Protobuf headers into
// ATen/core (which would require a lot more build system hacking.)
Expand Down
3 changes: 2 additions & 1 deletion aten/src/ATen/core/IdWrapper.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <functional>
#include <ATen/core/Macros.h>

namespace at {

Expand All @@ -21,7 +22,7 @@ namespace at {
* for you, given the underlying type supports it.
*/
template <class ConcreteType, class UnderlyingType>
class IdWrapper {
class AT_CORE_API IdWrapper {
public:
using underlying_type = UnderlyingType;
using concrete_type = ConcreteType;
Expand Down
4 changes: 4 additions & 0 deletions aten/src/ATen/core/TensorTypeId.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#include "ATen/core/TensorTypeId.h"

namespace at {

// Prints the raw numeric id; handy when a TensorTypeId shows up in an
// error message (e.g. AT_ERROR in Backend.h).
std::ostream& operator<<(std::ostream& str, at::TensorTypeId rhs) {
  str << rhs.underlyingId();
  return str;
}

} // namespace at
13 changes: 5 additions & 8 deletions aten/src/ATen/core/TensorTypeId.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@
#include <string>
#include <unordered_set>
#include "ATen/core/IdWrapper.h"

namespace at {
class TensorTypeId;
}

std::ostream& operator<<(std::ostream&, at::TensorTypeId);
#include "ATen/core/Macros.h"

namespace at {

Expand All @@ -22,7 +17,7 @@ using _tensorTypeId_underlyingType = uint8_t;
* Dynamic type ID of a Tensor argument. It represents something like
* CPUTensor, etc.
*/
class TensorTypeId final
class AT_CORE_API TensorTypeId final
: public at::
IdWrapper<TensorTypeId, details::_tensorTypeId_underlyingType> {
public:
Expand All @@ -37,9 +32,11 @@ class TensorTypeId final
: IdWrapper(id) {}

friend class TensorTypeIdCreator;
friend std::ostream& ::operator<<(std::ostream&, TensorTypeId);
friend AT_CORE_API std::ostream& operator<<(std::ostream&, TensorTypeId);
};

AT_CORE_API std::ostream& operator<<(std::ostream&, at::TensorTypeId);

} // namespace at

AT_DEFINE_HASH_FOR_IDWRAPPER(at::TensorTypeId)
11 changes: 8 additions & 3 deletions aten/src/ATen/core/TensorTypeIdRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

namespace at {

constexpr at::TensorTypeId TensorTypeIdCreator::max_id_;

TensorTypeIds::TensorTypeIds() : creator_(), registry_() {}

TensorTypeIds& TensorTypeIds::singleton() {
Expand All @@ -16,9 +14,10 @@ TensorTypeIds& TensorTypeIds::singleton() {
TensorTypeIdCreator::TensorTypeIdCreator() : last_id_(0) {}

at::TensorTypeId TensorTypeIdCreator::create() {

auto id = TensorTypeId(++last_id_);

if (id == max_id_) {
if (last_id_ == 0) { // overflow happened!
// If this happens in prod, we have to change
// details::_tensorTypeId_underlyingType to uint16_t.
AT_ERROR(
Expand Down Expand Up @@ -59,4 +58,10 @@ TensorTypeIdRegistrar::~TensorTypeIdRegistrar() {
TensorTypeIds::singleton().deregister(id_);
}

AT_DEFINE_TENSOR_TYPE(UndefinedTensorId);
AT_DEFINE_TENSOR_TYPE(CPUTensorId);
AT_DEFINE_TENSOR_TYPE(CUDATensorId);
AT_DEFINE_TENSOR_TYPE(SparseCPUTensorId);
AT_DEFINE_TENSOR_TYPE(SparseCUDATensorId);

} // namespace at
23 changes: 13 additions & 10 deletions aten/src/ATen/core/TensorTypeIdRegistration.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

namespace at {

class TensorTypeIdCreator final {
class AT_CORE_API TensorTypeIdCreator final {
public:
TensorTypeIdCreator();

Expand All @@ -29,13 +29,10 @@ class TensorTypeIdCreator final {
private:
std::atomic<details::_tensorTypeId_underlyingType> last_id_;

static constexpr at::TensorTypeId max_id_ = TensorTypeId(
std::numeric_limits<details::_tensorTypeId_underlyingType>::max());

AT_DISABLE_COPY_AND_ASSIGN(TensorTypeIdCreator);
};

class TensorTypeIdRegistry final {
class AT_CORE_API TensorTypeIdRegistry final {
public:
TensorTypeIdRegistry();

Expand All @@ -49,7 +46,7 @@ class TensorTypeIdRegistry final {
AT_DISABLE_COPY_AND_ASSIGN(TensorTypeIdRegistry);
};

class TensorTypeIds final {
class AT_CORE_API TensorTypeIds final {
public:
static TensorTypeIds& singleton();

Expand All @@ -71,7 +68,7 @@ inline constexpr at::TensorTypeId TensorTypeIds::undefined() noexcept {
return TensorTypeIdCreator::undefined();
}

class TensorTypeIdRegistrar final {
class AT_CORE_API TensorTypeIdRegistrar final {
public:
TensorTypeIdRegistrar();
~TensorTypeIdRegistrar();
Expand All @@ -88,12 +85,18 @@ inline at::TensorTypeId TensorTypeIdRegistrar::id() const noexcept {
return id_;
}

} // namespace at

#define AT_DECLARE_TENSOR_TYPE(TensorName) at::TensorTypeId TensorName();
#define AT_DECLARE_TENSOR_TYPE(TensorName) AT_CORE_API at::TensorTypeId TensorName();

#define AT_DEFINE_TENSOR_TYPE(TensorName) \
at::TensorTypeId TensorName() { \
static TensorTypeIdRegistrar registration_raii; \
return registration_raii.id(); \
}

AT_DECLARE_TENSOR_TYPE(UndefinedTensorId);
AT_DECLARE_TENSOR_TYPE(CPUTensorId); // Caffe2 supported
AT_DECLARE_TENSOR_TYPE(CUDATensorId); // Caffe2 supported
AT_DECLARE_TENSOR_TYPE(SparseCPUTensorId);
AT_DECLARE_TENSOR_TYPE(SparseCUDATensorId);

} // namespace at
Loading