
Merge from upstream, including our last PR #96


Merged: 17 commits merged on Aug 3, 2018
2 changes: 1 addition & 1 deletion .jenkins/pytorch/build.sh
@@ -32,6 +32,7 @@ pip install -r requirements.txt || true
if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
# This is necessary in order to cross compile (or else we'll have missing GPU device).
export MAX_JOBS=4
# This is necessary in order to cross compile (or else we'll have missing GPU device).
export HCC_AMDGPU_TARGET=gfx900

# These environment variables are not set on CI when we were running as the Jenkins user.
@@ -42,7 +43,6 @@ if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
# This environment variable enabled HCC Optimizations that speed up the linking stage.
# https://github.com/RadeonOpenCompute/hcc#hcc-with-thinlto-linking
export KMTHINLTO=1

python tools/amd_build/build_pytorch_amd.py
USE_ROCM=1 python setup.py install --user
exit 0
5 changes: 1 addition & 4 deletions aten/src/ATen/Context.cpp
@@ -37,11 +37,8 @@ Context::Context()
Type::registerCPU(this);
}

// NB: Ensure that globalContext is initialized before we load
// variable hooks, otherwise we will deadlock. Regardless, the
// deadlock is bad, and being tracked at https://github.com/pytorch/pytorch/issues/9784
static Context globalContext_;
Context & globalContext() {
static Context globalContext_;
return globalContext_;
}

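A side note on the hunk above: moving static Context globalContext_ to namespace scope makes the context get constructed during static initialization of the translation unit instead of lazily on first call, which is what the new NB comment about variable hooks and the deadlock refers to. A toy, self-contained sketch of the difference; the Context below is invented for illustration and is not the ATen class:

    // Toy sketch (not the ATen code) of the two initialization strategies the
    // Context.cpp hunk is choosing between.
    #include <iostream>

    struct Context {
      Context() { std::cout << "Context constructed\n"; }
    };

    // Namespace-scope static: constructed during static initialization of this
    // translation unit, before main() and before any later hook registration.
    static Context globalContext_;

    Context& globalContext() {
      // With a function-local static instead, construction would be deferred to
      // the first call, which is what the removed comment warns can deadlock.
      return globalContext_;
    }

    int main() {
      std::cout << "main starts\n";
      globalContext();  // already constructed by this point
      return 0;
    }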
10 changes: 5 additions & 5 deletions aten/src/ATen/SparseTensorImpl.cpp
@@ -18,14 +18,14 @@ namespace at {
// tensor and a [0] size values tensor for such an empty tensor. However,
// we don't currently support zero-size dimensions, so we can't actually
// do this; so we just allocate zero-size tensors for everything.
SparseTensorImpl::SparseTensorImpl(at::Backend backend, at::ScalarType scalar_type)
: TensorImpl(backend, scalar_type, nullptr, false)
SparseTensorImpl::SparseTensorImpl(Type * type)
: TensorImpl(type, nullptr)
, size_{0}
, sparseDims_(1)
, denseDims_(0)
, indices_(globalContext().getTypeOpt(toDense(backend), ScalarType::Long)->tensor())
, values_(globalContext().getTypeOpt(toDense(backend), scalar_type)->tensor()) {
AT_ASSERT(backend == Backend::SparseCPU || backend == Backend::SparseCUDA);
, indices_(type->toDense().toScalarType(ScalarType::Long).tensor())
, values_(type->toDense().tensor()) {
AT_ASSERT(type->is_sparse());
}

IntList SparseTensorImpl::sizes() const {
2 changes: 1 addition & 1 deletion aten/src/ATen/SparseTensorImpl.h
@@ -48,7 +48,7 @@ struct AT_API SparseTensorImpl : public TensorImpl {

public:
// Public for now...
explicit SparseTensorImpl(at::Backend, at::ScalarType);
explicit SparseTensorImpl(Type * type);

int64_t nnz() const { return nnz_; }
int64_t sparseDims() const { return sparseDims_; }
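For reviewers following the constructor change: the new SparseTensorImpl takes the sparse Type and derives the indices type (dense, Long) and values type (dense, same scalar type) from it. A toy model of that derivation; ToyType and its string-based members are invented for illustration and are not ATen APIs:

    // Toy model of the type derivation in the new SparseTensorImpl constructor:
    // indices get the dense backend with Long scalars, values get the dense
    // backend with the original scalar type.
    #include <cassert>
    #include <string>

    struct ToyType {
      std::string backend;      // e.g. "SparseCPU" or "CPU"
      std::string scalar_type;  // e.g. "Float" or "Long"

      bool is_sparse() const { return backend.rfind("Sparse", 0) == 0; }
      // Only meaningful for sparse backends in this toy.
      ToyType toDense() const { return {backend.substr(6), scalar_type}; }
      ToyType toScalarType(const std::string& s) const { return {backend, s}; }
    };

    int main() {
      ToyType sparse_float{"SparseCPU", "Float"};
      assert(sparse_float.is_sparse());

      ToyType indices_type = sparse_float.toDense().toScalarType("Long");
      ToyType values_type  = sparse_float.toDense();

      assert(indices_type.backend == "CPU" && indices_type.scalar_type == "Long");
      assert(values_type.backend == "CPU" && values_type.scalar_type == "Float");
      return 0;
    }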
11 changes: 6 additions & 5 deletions aten/src/ATen/StorageImpl.cpp
@@ -28,11 +28,12 @@ StorageImpl::StorageImpl(
allocator,
resizable) {}

Type& StorageImpl::type() {
if (data_ptr.device().is_cuda()) {
return globalContext().getType(Backend::CUDA, scalar_type);
namespace detail {
Backend get_backend(StorageImpl* storage_impl) {
if (storage_impl->data_ptr.device().is_cuda()) {
return Backend::CUDA;
}
return globalContext().getType(Backend::CPU, scalar_type);
return Backend::CPU;
}

} // namespace detail
} // namespace at
5 changes: 3 additions & 2 deletions aten/src/ATen/StorageImpl.h
@@ -91,8 +91,6 @@ struct AT_API StorageImpl : public Retainable {
return at::elementSize(scalar_type);
}

Type& type();

//TODO: Rename to size() and size to size_
size_t get_size() const {
return size;
@@ -112,4 +110,7 @@ }
}
};

namespace detail {
AT_API Backend get_backend(StorageImpl* storage_impl);
}
} // namespace at
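The StorageImpl change replaces the type() member, which pulled in globalContext and Type, with a free function in at::detail that only answers the backend question. A self-contained sketch of the pattern, using a stand-in struct rather than the real StorageImpl:

    // Sketch of moving a query off the class into a detail free function so the
    // storage header no longer depends on Type/Context. ToyStorage stands in for
    // at::StorageImpl and only models the "is the data on a CUDA device" bit.
    #include <cassert>

    enum class Backend { CPU, CUDA };

    struct ToyStorage {
      bool data_on_cuda;
    };

    namespace detail {
    inline Backend get_backend(const ToyStorage* storage) {
      return storage->data_on_cuda ? Backend::CUDA : Backend::CPU;
    }
    }  // namespace detail

    int main() {
      ToyStorage cpu_storage{false};
      ToyStorage cuda_storage{true};
      assert(detail::get_backend(&cpu_storage) == Backend::CPU);
      assert(detail::get_backend(&cuda_storage) == Backend::CUDA);
      return 0;
    }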
13 changes: 0 additions & 13 deletions aten/src/ATen/TensorImpl.cpp
@@ -2,23 +2,10 @@

#include <ATen/Tensor.h>
#include <ATen/optional.h>
#include <ATen/Context.h>

#include <ATen/detail/VariableHooksInterface.h>

#include <TH/THTensor.hpp>

namespace at {

Type& TensorImpl::type() const {
Type* base_type = &globalContext().getType(backend_, scalar_type_);
if (is_variable_) {
return detail::getVariableHooks().getVariableType(*base_type);
} else {
return *base_type;
}
}

Tensor& TensorImpl::grad() {
AT_ERROR("grad is not implemented for Tensor");
}
18 changes: 6 additions & 12 deletions aten/src/ATen/TensorImpl.h
@@ -18,18 +18,16 @@ struct Tensor;

namespace at {
struct AT_API TensorImpl : public Retainable {
explicit TensorImpl(Backend backend, ScalarType scalar_type, THTensor * tensor, bool is_variable)
: backend_(backend), scalar_type_(scalar_type), is_variable_(is_variable), tensor(tensor) {}
explicit TensorImpl(Type * type, THTensor * tensor)
: type_(type), tensor(tensor) {}

virtual ~TensorImpl();

virtual void release_resources() override;

// The implementation of this method will have to be hoisted out and
// hooked in, so that Caffe2 doesn't need to know about Context
// TODO: This really really needs to be inlined.
Type & type() const;

Type & type() const {
return *type_;
}
const char * toString() const;
virtual IntList sizes() const;
virtual IntList strides() const;
@@ -93,12 +91,8 @@ struct AT_API TensorImpl : public Retainable {
virtual void set_data(Tensor new_data);

protected:
Backend backend_;
// INVARIANT: When storage is non-null, this scalar type must
// agree with the scalar type in storage
ScalarType scalar_type_;
bool is_variable_ = false;
bool is_wrapped_number_ = false;
Type * type_;
public:
THTensor * tensor;
};
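Net effect of the TensorImpl changes: the impl now caches a Type* set at construction, so type() becomes a trivial inline accessor instead of an out-of-line registry lookup plus variable-hook dispatch. A simplified before/after sketch; the names and the registry stand-in below are illustrative, not the real ATen declarations:

    // Simplified contrast of the two type() strategies; Type is left as an
    // empty stand-in for at::Type.
    struct Type {};

    // Stand-in for globalContext().getType(backend, scalar_type).
    inline Type& lookup_type_in_global_registry(int /*backend*/, int /*scalar_type*/) {
      static Type the_type;
      return the_type;
    }

    // Before: the impl stored (backend, scalar_type) and re-derived the Type on
    // every call, which also forced TensorImpl.cpp to include Context and the
    // variable-hooks interface.
    struct OldStyleImpl {
      int backend_;
      int scalar_type_;
      Type& type() const {
        return lookup_type_in_global_registry(backend_, scalar_type_);
      }
    };

    // After: the Type* is fixed at construction and type() is a pointer deref
    // that can live entirely in the header.
    struct NewStyleImpl {
      explicit NewStyleImpl(Type* type) : type_(type) {}
      Type& type() const { return *type_; }
     private:
      Type* type_;
    };

    int main() {
      Type t;
      NewStyleImpl impl(&t);
      return (&impl.type() == &t) ? 0 : 1;  // same object, no registry involved
    }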
2 changes: 1 addition & 1 deletion aten/src/ATen/UndefinedTensor.cpp
@@ -6,7 +6,7 @@ namespace at {

// should this use the globalContext? Can it get a context passed in somehow?
UndefinedTensor::UndefinedTensor()
: TensorImpl(Backend::Undefined, ScalarType::Undefined, nullptr, /* is variable */ false) {
: TensorImpl(&(globalContext().getType(Backend::Undefined,ScalarType::Undefined)), nullptr) {
}

IntList UndefinedTensor::sizes() const {
9 changes: 5 additions & 4 deletions aten/src/ATen/Utils.h
@@ -1,6 +1,7 @@
#pragma once

#include "ATen/ATenGeneral.h"
#include "ATen/StorageImpl.h"
#include "ATen/ArrayRef.h"
#include "ATen/Error.h"
#include "ATen/UndefinedTensor.h"
@@ -24,12 +25,12 @@ AT_API int _crash_if_asan(int);

template <typename T, typename Base>
static inline T* checked_cast_storage(Base* expr, const char * name, int pos, Backend backend, ScalarType scalar_type) {
if (expr->pImpl()->type().backend() != backend) {
AT_ERROR("Expected object of backend ", backend, " but got backend ", expr->pImpl()->type().backend(),
if (at::detail::get_backend(expr->pImpl()) != backend) {
AT_ERROR("Expected object of backend ", backend, " but got backend ", at::detail::get_backend(expr->pImpl()),
" for argument #", pos, " '", name, "'");
}
if (expr->pImpl()->type().scalarType() != scalar_type) {
AT_ERROR("Expected object of scalar type ", scalar_type, " but got scalar type ", expr->pImpl()->type().scalarType(),
if (expr->pImpl()->scalar_type != scalar_type) {
AT_ERROR("Expected object of scalar type ", scalar_type, " but got scalar type ", expr->pImpl()->scalar_type,
" for argument #", pos, " '", name, "'");
}
// NB: We're getting rid of derived types soon!
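The Utils.h hunk keeps the same two validations, backend first and then scalar type, but reads them via at::detail::get_backend and the impl's scalar_type field instead of going through type(). A compact sketch of that check-then-cast shape, with toy types and exceptions standing in for the real storage classes and AT_ERROR:

    // Toy check-then-cast helper: validate backend and scalar type, then
    // downcast. BaseStorage and FloatCpuStorage are invented for the sketch.
    #include <stdexcept>

    enum class Backend { CPU, CUDA };
    enum class ScalarType { Float, Long };

    struct BaseStorage {
      Backend backend;
      ScalarType scalar_type;
      virtual ~BaseStorage() {}
    };

    struct FloatCpuStorage : BaseStorage {
      FloatCpuStorage() { backend = Backend::CPU; scalar_type = ScalarType::Float; }
    };

    template <typename T>
    T* checked_cast_storage(BaseStorage* expr, Backend backend, ScalarType scalar_type) {
      if (expr->backend != backend) {
        throw std::runtime_error("Expected a different backend");
      }
      if (expr->scalar_type != scalar_type) {
        throw std::runtime_error("Expected a different scalar type");
      }
      return static_cast<T*>(expr);  // cast only after both checks pass
    }

    int main() {
      FloatCpuStorage storage;
      BaseStorage* base = &storage;
      // Succeeds: backend and scalar type both match.
      FloatCpuStorage* ok =
          checked_cast_storage<FloatCpuStorage>(base, Backend::CPU, ScalarType::Float);
      (void)ok;
      // Throws: wrong scalar type requested.
      try {
        checked_cast_storage<FloatCpuStorage>(base, Backend::CPU, ScalarType::Long);
      } catch (const std::runtime_error&) {
        // expected
      }
      return 0;
    }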
5 changes: 0 additions & 5 deletions aten/src/ATen/detail/VariableHooksInterface.h
@@ -3,7 +3,6 @@
#include <ATen/Registry.h>
#include <ATen/Error.h>
#include <ATen/ScalarType.h>
#include <ATen/Type.h>

namespace at {
class Context;
@@ -26,10 +25,6 @@ struct AT_API VariableHooksInterface {
// squelch -Werror=non-virtual-dtor
virtual ~VariableHooksInterface() {}

virtual Type& getVariableType(const at::Type& baseType) const {
AT_ERROR("cannot getVariableType without libtorch");
}

virtual void registerVariableTypeFor(Context*, Backend backend, ScalarType scalar_type) const {
// no-op if Variable not available; it'll get handled (if at all) when
// libtorch.so gets loaded
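The hooks header keeps its shape after this PR: a virtual interface whose default implementations are no-ops so ATen runs without libtorch, with getVariableType dropped because TensorImpl no longer needs it. A toy version of that pattern; the unique_ptr registry below is only a stand-in for the real hook registration machinery:

    // Toy version of the "optional hooks" pattern: no-op defaults in the base
    // class, overridden when the optional library installs its own hooks.
    #include <iostream>
    #include <memory>

    struct ToyVariableHooks {
      virtual ~ToyVariableHooks() {}
      // Default is a deliberate no-op so the core library runs stand-alone.
      virtual void registerVariableTypeFor(int backend, int scalar_type) const {
        (void)backend;
        (void)scalar_type;
      }
    };

    // What a later-loaded library (libtorch in the real code) would install.
    struct LoadedVariableHooks : ToyVariableHooks {
      void registerVariableTypeFor(int backend, int scalar_type) const override {
        std::cout << "registering variable type for backend " << backend
                  << ", scalar type " << scalar_type << "\n";
      }
    };

    std::unique_ptr<ToyVariableHooks>& hooks() {
      static std::unique_ptr<ToyVariableHooks> h(new ToyVariableHooks());
      return h;
    }

    int main() {
      hooks()->registerVariableTypeFor(0, 0);    // silent no-op
      hooks().reset(new LoadedVariableHooks());  // simulate the library loading
      hooks()->registerVariableTypeFor(0, 0);    // now hits the override
      return 0;
    }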
2 changes: 1 addition & 1 deletion aten/src/ATen/native/sparse/SparseTensor.cpp
@@ -63,7 +63,7 @@ SparseTensor new_sparse(const SparseType& dtype) {
AT_ASSERT(!dtype.is_variable());
AT_ASSERT(dtype.is_sparse());
// TODO: Hmm... this const_cast business seems a bit dodgy
return SparseTensor(new SparseTensorImpl(dtype.backend(), dtype.scalarType()), /* retain */ false);
return SparseTensor(new SparseTensorImpl(const_cast<SparseType*>(&dtype)), /* retain */ false);
}

/*** Helper methods ***/
2 changes: 1 addition & 1 deletion aten/src/ATen/templates/TensorDerived.cpp
@@ -21,7 +21,7 @@ namespace detail {
}

${Tensor}::${Tensor}(${THTensor} * tensor)
: TensorImpl(Backend::${Backend}, ScalarType::${ScalarName}, tensor, /* is variable */ false)
: TensorImpl(&globalContext().getType(Backend::${Backend},ScalarType::${ScalarName}), tensor)
{}

${TensorDenseOrSparse}
75 changes: 36 additions & 39 deletions aten/src/TH/THTensor.hpp
@@ -56,10 +56,6 @@ struct THTensor
return sizes_.size();
}

at::ScalarType scalar_type() const {
return storage_->scalar_type;
}

ptrdiff_t storage_offset() const {
return storage_offset_;
}
@@ -127,41 +123,6 @@ inline THStorage* THTensor_getStoragePtr(const THTensor* tensor) {
return tensor->storage_;
}

#include "generic/THTensorFastGetSet.hpp"
#include "THGenerateAllTypes.h"

inline void THTensor_resizeDim(THTensor* tensor, int64_t ndim) {
// NB: This is *truly* a resize; calling code (e.g., squeeze)
// assumes that old values are preserved
tensor->is_zero_dim_ = bool(ndim == 0);
tensor->sizes_.resize(ndim);
tensor->strides_.resize(ndim);
}

inline void THTensor_setSizesAndStrides(THTensor* tensor, std::vector<int64_t>&& new_size, std::vector<int64_t>&& new_stride) {
tensor->sizes_ = std::move(new_size);
tensor->strides_ = std::move(new_stride);
}

inline void THTensor_setSizeAtDim(THTensor* tensor, int dim, int64_t new_size) {
tensor->sizes_[dim] = new_size;
}

inline void THTensor_setStrideAtDim(THTensor* tensor, int dim, int64_t new_stride) {
tensor->strides_[dim] = new_stride;
}

inline void THTensor_setStorageOffset(THTensor* tensor, ptrdiff_t storage_offset) {
tensor->storage_offset_ = storage_offset;
}

// NB: Steals ownership of storage
inline void THTensor_stealAndSetStoragePtr(THTensor* tensor, THStorage* storage) {
// Caffe2 might have tensors whose storages are null, but we
// don't allow it in PyTorch.
AT_ASSERT(storage);
tensor->storage_ = storage;
}

inline bool THTensor_isZeroDim(const THTensor *tensor) {
return tensor->is_zero_dim_;
@@ -209,6 +170,42 @@ inline int64_t THTensor_sizeLegacyNoScalars(const THTensor *self, int dim)
return THTensor_isZeroDim(self) ? 1 : self->size(dim);
}

#include "generic/THTensorFastGetSet.hpp"
#include "THGenerateAllTypes.h"

inline void THTensor_resizeDim(THTensor* tensor, int64_t ndim) {
// NB: This is *truly* a resize; calling code (e.g., squeeze)
// assumes that old values are preserved
tensor->is_zero_dim_ = bool(ndim == 0);
tensor->sizes_.resize(ndim);
tensor->strides_.resize(ndim);
}

inline void THTensor_setSizesAndStrides(THTensor* tensor, std::vector<int64_t>&& new_size, std::vector<int64_t>&& new_stride) {
tensor->sizes_ = std::move(new_size);
tensor->strides_ = std::move(new_stride);
}

inline void THTensor_setSizeAtDim(THTensor* tensor, int dim, int64_t new_size) {
tensor->sizes_[dim] = new_size;
}

inline void THTensor_setStrideAtDim(THTensor* tensor, int dim, int64_t new_stride) {
tensor->strides_[dim] = new_stride;
}

inline void THTensor_setStorageOffset(THTensor* tensor, ptrdiff_t storage_offset) {
tensor->storage_offset_ = storage_offset;
}

// NB: Steals ownership of storage
inline void THTensor_stealAndSetStoragePtr(THTensor* tensor, THStorage* storage) {
// Caffe2 might have tensors whose storages are null, but we
// don't allow it in PyTorch.
AT_ASSERT(storage);
tensor->storage_ = storage;
}

TH_API void THTensor_free(THTensor *self);
TH_CPP_API at::optional<std::vector<int64_t>> THTensor_compute_stride(at::IntList oldshape, at::IntList oldstride,
at::IntList newshape);
10 changes: 5 additions & 5 deletions aten/src/TH/THTensorApply.h
@@ -47,9 +47,9 @@
TENSOR##_size = 1; \
TENSOR##_stride = 1; \
for(TENSOR##_i = THTensor_nDimensionLegacyAll(TENSOR)-1; TENSOR##_i >= 0; TENSOR##_i--) { \
if(TENSOR->size(TENSOR##_i) != 1) { \
if(TENSOR->stride(TENSOR##_i) == TENSOR##_size && TENSOR##_i != DIM) \
TENSOR##_size *= TENSOR->size(TENSOR##_i); \
if(THTensor_sizeLegacyNoScalars(TENSOR, TENSOR##_i) != 1) { \
if(THTensor_strideLegacyNoScalars(TENSOR, TENSOR##_i) == TENSOR##_size && TENSOR##_i != DIM) \
TENSOR##_size *= THTensor_sizeLegacyNoScalars(TENSOR, TENSOR##_i); \
else{ \
TENSOR##_contiguous = 0; \
break; \
@@ -70,8 +70,8 @@
TENSOR##_strides = TENSOR##_counter + 2*TENSOR##_dim; \
TH_TENSOR_dim_index = TENSOR##_dim-1; \
TENSOR##_dimOffset = (DIM == THTensor_nDimensionLegacyAll(TENSOR)-1) ? &TENSOR##_i : &TENSOR##_counter[DIM]; \
TENSOR##_sizes[TH_TENSOR_dim_index] = TENSOR->size(THTensor_nDimensionLegacyAll(TENSOR)-1); \
TENSOR##_strides[TH_TENSOR_dim_index] = TENSOR->stride(THTensor_nDimensionLegacyAll(TENSOR)-1); \
TENSOR##_sizes[TH_TENSOR_dim_index] = THTensor_sizeLegacyNoScalars(TENSOR, THTensor_nDimensionLegacyAll(TENSOR)-1); \
TENSOR##_strides[TH_TENSOR_dim_index] = THTensor_strideLegacyNoScalars(TENSOR, THTensor_nDimensionLegacyAll(TENSOR)-1); \
/* TENSOR##_counter tracks where we are in the storage. The offset into the */ \
/* storage is given by storage_offset + (i * j), where i is the stride */ \
/* vector and j is tensor_counter vector. This sets the starting position for the loop. */ \
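The macro edits above route size and stride reads through the LegacyNoScalars helpers, which treat a zero-dimensional tensor as having one dimension of size 1 and stride 1. A standalone sketch of that convention; ToyTensor and the helper names are invented here, the real helpers live in THTensor.hpp:

    // Toy model of the "legacy no scalars" convention: a zero-dim tensor
    // reports size 1 and stride 1 so dimension-based loops still work.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct ToyTensor {
      std::vector<int64_t> sizes;
      std::vector<int64_t> strides;  // empty vectors model a zero-dim tensor
    };

    inline int64_t sizeLegacyNoScalars(const ToyTensor& t, int dim) {
      return t.sizes.empty() ? 1 : t.sizes.at(dim);
    }

    inline int64_t strideLegacyNoScalars(const ToyTensor& t, int dim) {
      return t.strides.empty() ? 1 : t.strides.at(dim);
    }

    int main() {
      ToyTensor zero_dim;                // scalar-like tensor, no dimensions
      ToyTensor matrix{{2, 3}, {3, 1}};  // 2x3 row-major tensor

      assert(sizeLegacyNoScalars(zero_dim, 0) == 1);
      assert(strideLegacyNoScalars(zero_dim, 0) == 1);
      assert(sizeLegacyNoScalars(matrix, 1) == 3);
      assert(strideLegacyNoScalars(matrix, 0) == 3);
      return 0;
    }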