Merge from upstream #133


Merged: 29 commits, Aug 17, 2018
8fdba4e  Move all operator<< overloads out of the global namespace. (#10546)  [ezyang, Aug 16, 2018]
fdd2b9b  Add DataType alias  [ezyang, Aug 16, 2018]
c6facc2  Add conversions between DataType and ScalarType.  [ezyang, Aug 16, 2018]
130881f  Delete build_caffe2.sh, replace with build_libtorch.py (#10508)  [anderspapitto, Aug 16, 2018]
00f2731  Merge THTensor into TensorImpl  [gchanan, Aug 16, 2018]
319fefe  Support benchmark on windows machines  [sf-wind, Aug 16, 2018]
4be4b4c  Remove weight from input of onnxifi backend op (#10575)  [Aug 16, 2018]
40a0704  Adding new allreduce bcube routines to ops supported by gloo (#10494)  [kirteshpatil, Aug 16, 2018]
d6f3c88  Revert D9076734: Split storage from tensor  [Aug 16, 2018]
ef15bb8  remove implicit conversion from gpu to cpu (#10553)  [Aug 16, 2018]
488ea82  Additional changes to make GPU builds work (#10507)  [orionr, Aug 16, 2018]
342517e  Back out "Add aten_op to caffe2 onnx (python) backend" (#10589)  [bddppq, Aug 16, 2018]
afd7477  Add ``buffers()``, ``named_buffers()`` methods. (#10554)  [jma127, Aug 16, 2018]
67c6d93  Tune minimal work size (#10599)  [Aug 17, 2018]
c101a57  Build mechanism for custom operators (#10226)  [goldsborough, Aug 17, 2018]
f1d40ef  build_pytorch_libs.sh: use MAX_JOBS rather than NUM_JOBS (#10600)  [anderspapitto, Aug 17, 2018]
3578909  Remove unused code base for distributed training (#10282)  [heslami, Aug 17, 2018]
6667d55  Disallow input filler for GatherRangesOp (#10592)  [csummersea, Aug 17, 2018]
0aefb9f  Update onnx to onnx/onnx@7848f1e (#10613)  [onnxbot, Aug 17, 2018]
cc53807  group conv with NHWC layout (#10585)  [jspark1105, Aug 17, 2018]
5122250  Add AT_CORE_EXPORT and AT_CORE_IMPORT. (#10602)  [tolia-msft, Aug 17, 2018]
82a5a4e  Merge remote-tracking branch 'upstream/master'  [iotamudelta, Aug 17, 2018]
ff3a481  fix python interpreter can not be found (#10543)  [Aug 17, 2018]
03982fb  Fix subgraph cutting wrt recent external_input change in nomnigraph (…  [Aug 17, 2018]
31c7a32  Include aten_op by default in caffe2  [bddppq, Aug 17, 2018]
e190505  Adding support for inlining if branches (#10084)  [Aug 17, 2018]
ff440b6  Revert D9378844: [pytorch][PR] fix python interpreter can not be found  [ezyang, Aug 17, 2018]
bd9ab65  fix compile error in math_hip.cc from new Im2Col/Col2Im interface (#1…  [jspark1105, Aug 17, 2018]
e26bb17  Merge remote-tracking branch 'upstream/master'  [iotamudelta, Aug 17, 2018]
10 changes: 7 additions & 3 deletions .jenkins/pytorch/build.sh
@@ -42,11 +42,11 @@ if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
# This environment variable enabled HCC Optimizations that speed up the linking stage.
# https://github.com/RadeonOpenCompute/hcc#hcc-with-thinlto-linking
export KMTHINLTO=1

# Need the libc++1 and libc++abi1 libraries to allow torch._C to load at runtime
sudo apt-get install libc++1
sudo apt-get install libc++abi1

python tools/amd_build/build_pytorch_amd.py
USE_ROCM=1 python setup.py install --user
exit 0
@@ -118,5 +118,9 @@ if [[ "$BUILD_TEST_LIBTORCH" == "1" ]]; then
echo "Building libtorch"
# NB: Install outside of source directory (at the same level as the root
# pytorch folder) so that it doesn't get cleaned away prior to docker push.
WERROR=1 VERBOSE=1 tools/cpp_build/build_caffe2.sh "$PWD/../cpp-build"
BUILD_LIBTORCH_PY=$PWD/tools/build_libtorch.py
mkdir -p ../cpp-build/caffe2
pushd ../cpp-build/caffe2
WERROR=1 VERBOSE=1 DEBUG=1 python $BUILD_LIBTORCH_PY
popd
fi
8 changes: 6 additions & 2 deletions .jenkins/pytorch/macos-test.sh
@@ -56,8 +56,12 @@ test_cpp_api() {
#
CPP_BUILD="$PWD/../cpp-build"
rm -rf $CPP_BUILD
mkdir -p $CPP_BUILD
WERROR=1 VERBOSE=1 tools/cpp_build/build_caffe2.sh "$CPP_BUILD"
mkdir -p $CPP_BUILD/caffe2

BUILD_LIBTORCH_PY=$PWD/tools/build_libtorch.py
pushd $CPP_BUILD/caffe2
WERROR=1 VERBOSE=1 DEBUG=1 python $BUILD_LIBTORCH_PY
popd

python tools/download_mnist.py --quiet -d test/cpp/api/mnist

2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -230,7 +230,7 @@ if(NOT MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Qunused-arguments")
endif()
if ((APPLE AND (NOT ("${CLANG_VERSION_STRING}" VERSION_LESS "9.0")))
OR (CMAKE_COMPILER_IS_GNUCXX
OR (CMAKE_COMPILER_IS_GNUCXX
AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0 AND NOT APPLE)))
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -faligned-new")
endif()
4 changes: 2 additions & 2 deletions aten/src/ATen/Declarations.cwrap
@@ -2394,7 +2394,7 @@
if (self_->dim() == 0) {
throw std::runtime_error("Input must be 1-d or 2-d");
}
${THTensor}_diag(${state,}result_->tensor, self_->tensor, diagonal);
${THTensor}_diag(${state,}result_, self_, diagonal);
result_->maybe_zero_dim(self_->dim() == 0);
]]
[[
@@ -2986,7 +2986,7 @@
- arg: real tol
default: -1
aten_custom_call: |
${THTensor}_pstrf(res1_->tensor, res2_->tensor, self_->tensor, (upper) ? "U" : "L", tol_);
${THTensor}_pstrf(res1_, res2_, self_, (upper) ? "U" : "L", tol_);
res2 -= 1; // LAPACK returns 1-indexed pivots
]]
[[
2 changes: 1 addition & 1 deletion aten/src/ATen/Registry.h
@@ -114,7 +114,7 @@ class AT_API Registry {
};

template <class SrcType, class ObjectPtrType, class... Args>
class Registerer {
class AT_API Registerer {
public:
Registerer(
const SrcType& key,
20 changes: 15 additions & 5 deletions aten/src/ATen/SparseTensorImpl.cpp
@@ -31,7 +31,7 @@ namespace {
// we don't currently support zero-size dimensions, so we can't actually
// do this; so we just allocate zero-size tensors for everything.
SparseTensorImpl::SparseTensorImpl(at::TensorTypeId type_id, at::ScalarType scalar_type)
: TensorImpl(type_id, scalar_type, nullptr, false)
: TensorImpl(type_id, scalar_type, false)
, size_{0}
, sparseDims_(1)
, denseDims_(0)
@@ -44,6 +44,14 @@ IntList SparseTensorImpl::sizes() const {
IntList SparseTensorImpl::strides() const {
AT_ERROR("sparse tensors do not have strides");
}
int64_t SparseTensorImpl::size(int64_t d) const {
d = at::maybe_wrap_dim(d, dim(), false);
return size_[d];
}
int64_t SparseTensorImpl::stride(int64_t d) const {
AT_ERROR("sparse tensors do not have strides");
}

int64_t SparseTensorImpl::dim() const {
return sparseDims_ + denseDims_;
}
@@ -54,13 +62,15 @@ TensorImpl* SparseTensorImpl::maybe_zero_dim(bool condition_when_zero_dim) {
" changing dimensionality via maybe_zero_dim");
return this;
}
void * SparseTensorImpl::unsafeGetTH(bool retain) {
AT_ERROR("unsafeGetTH not supported for new style TensorImpl");
}
std::unique_ptr<Storage> SparseTensorImpl::storage() {
AT_ERROR("sparse tensors do not have storage");
}

at::StorageImpl* SparseTensorImpl::storageImpl() const {
AT_ERROR("sparse tensors do not have storage");
}
int64_t SparseTensorImpl::storage_offset() const {
AT_ERROR("sparse tensors do not have storage");
}
void SparseTensorImpl::set_indices_and_values(const Tensor& indices, const Tensor& values) {
// TODO: Explicit empty test is needed because we don't handle size zero
// dimensions at the moment
6 changes: 5 additions & 1 deletion aten/src/ATen/SparseTensorImpl.h
@@ -59,10 +59,14 @@ struct AT_API SparseTensorImpl : public TensorImpl {

IntList sizes() const override;
IntList strides() const override;
int64_t size(int64_t d) const override;
int64_t stride(int64_t d) const override;

int64_t dim() const override;
TensorImpl* maybe_zero_dim(bool condition_when_zero_dim) override;
void * unsafeGetTH(bool retain) override;
std::unique_ptr<Storage> storage() override;
at::StorageImpl* storageImpl() const override;
int64_t storage_offset() const override;

// Some ops do some manual size fiddling.
// TODO: Figure out a more safe way to provide this functionality
2 changes: 0 additions & 2 deletions aten/src/ATen/StorageImpl.h
@@ -1,7 +1,5 @@
#pragma once

#include <ATen/Scalar.h>

#include <ATen/Allocator.h>
#include <ATen/ScalarType.h>
#include <ATen/ScalarTypeUtils.h>
76 changes: 47 additions & 29 deletions aten/src/ATen/TensorImpl.cpp
@@ -59,59 +59,77 @@ void Tensor::backward(
pImpl->backward(std::move(gradient), keep_graph, create_graph);
}

TensorImpl::TensorImpl(TensorTypeId type_id, ScalarType scalar_type)
: type_id_(type_id), scalar_type_(scalar_type) {
auto type = &globalContext().getType(tensorTypeIdToBackend(type_id), scalar_type);
auto storage = type->storage(true);
StorageImpl* storage_impl = storage->pImpl();
storage_impl->retain();
tensor = new THTensor(storage_impl);
TensorImpl::TensorImpl(TensorTypeId type_id, ScalarType scalar_type, bool is_variable)
: TensorImpl(nullptr, type_id, scalar_type, is_variable) {
// UndefinedTensors and SparseTensors don't have storages.
if (type_id != UndefinedTensorId() && scalar_type != ScalarType::Undefined
&& type_id != SparseCPUTensorId() && type_id != SparseCUDATensorId()) {
auto type = &globalContext().getType(tensorTypeIdToBackend(type_id), scalar_type);
auto storage = type->storage(true);
storage_ = storage->pImpl();
storage_->retain();
}
}

TensorImpl::TensorImpl(StorageImpl* storage, TensorTypeId type_id, bool is_variable)
: TensorImpl(storage, type_id, storage->scalar_type(), is_variable) {}

TensorImpl::TensorImpl(StorageImpl* storage, TensorTypeId type_id, ScalarType scalar_type, bool is_variable)
: storage_(storage),
storage_offset_(0),
sizes_{0},
strides_{1},
type_id_(type_id),
scalar_type_(scalar_type),
is_variable_(is_variable) {}

TensorImpl::~TensorImpl() {
if (tensor) tensor->release();
if (storage_) {
storage_->release();
storage_ = nullptr;
}
}

IntList TensorImpl::sizes() const {
// NB: dim in tensor is not synchronized with THTensor, so it's
// important to apply dim here
return IntList(THTensor_getSizePtr(tensor), dim());
return sizes_;
}

IntList TensorImpl::strides() const {
// NB: dim in tensor is not synchronized with THTensor, so it's
// important to apply dim here
return IntList(THTensor_getStridePtr(tensor), dim());
return strides_;
}

void TensorImpl::release_resources() {
if (tensor) {
tensor->release();
tensor = nullptr;
if (storage_) {
storage_->release();
storage_ = nullptr;
}
}

int64_t TensorImpl::dim() const {
return tensor->dim();
return sizes_.size();
}

TensorImpl* TensorImpl::maybe_zero_dim(bool condition_when_zero_dim) {
AT_CHECK(tensor, "TensorImpl without THTensor in maybe_zero_dim");
THTensor_maybe_zero_dim(tensor, condition_when_zero_dim);
return this;
int64_t TensorImpl::size(int64_t d) const {
d = at::maybe_wrap_dim(d, dim(), false);
return sizes_[d];
}

int64_t TensorImpl::stride(int64_t d) const {
d = at::maybe_wrap_dim(d, dim(), false);
return strides_[d];
}

void * TensorImpl::unsafeGetTH(bool retain) {
if (retain) {
tensor->retain();
TensorImpl* TensorImpl::maybe_zero_dim(bool condition_when_zero_dim) {
bool set_zero_dim = condition_when_zero_dim && this->sizes().size() == 1 && this->size(0) == 1;
if (set_zero_dim) {
THTensor_resizeDim(this, 0);
}
return tensor;
return this;
}

std::unique_ptr<Storage> TensorImpl::storage() {
StorageImpl* storage = tensor->storage_;
storage->retain();
return std::unique_ptr<Storage>(new Storage(storage));
storage_->retain();
return std::unique_ptr<Storage>(new Storage(storage_));
}

} // namespace at
57 changes: 49 additions & 8 deletions aten/src/ATen/TensorImpl.h
@@ -3,8 +3,7 @@
#include <atomic>
#include <memory>

#include "ATen/Retainable.h"
#include "ATen/ScalarType.h"
#include "ATen/StorageImpl.h"
#include "ATen/core/optional.h"
#include "ATen/core/TensorTypeId.h"
#include "ATen/core/TensorTypeIdRegistration.h"
@@ -20,9 +19,8 @@ struct Tensor;

namespace at {
struct AT_API TensorImpl : public Retainable {
explicit TensorImpl(TensorTypeId type_id, ScalarType scalar_type, THTensor * tensor, bool is_variable)
: type_id_(type_id), scalar_type_(scalar_type), is_variable_(is_variable), tensor(tensor) {}
TensorImpl(TensorTypeId type_id, ScalarType scalar_type);
TensorImpl(TensorTypeId type_id, ScalarType scalar_type, bool is_variable);
TensorImpl(StorageImpl* storage, TensorTypeId type_id, bool is_variable);

virtual ~TensorImpl();

@@ -37,7 +35,6 @@ struct AT_API TensorImpl : public Retainable {
virtual IntList sizes() const;
virtual IntList strides() const;
virtual int64_t dim() const;
virtual void * unsafeGetTH(bool retain);
virtual std::unique_ptr<Storage> storage();
friend struct Type;

@@ -95,14 +92,58 @@ struct AT_API TensorImpl : public Retainable {

virtual void set_data(Tensor new_data);

// TODO: make these protected
// Note: storage->size() may be greater than the recorded size
// of a tensor
at::StorageImpl* storage_;
int64_t storage_offset_;

std::vector<int64_t> sizes_;
std::vector<int64_t> strides_;

template <typename T>
inline T * data() const {
return storageImpl()->data<T>() + storage_offset_;
}

template <typename T>
inline T * unsafe_data() const {
return storageImpl()->unsafe_data<T>() + storage_offset_;
}

inline at::ScalarType scalar_type() const {
return scalar_type_;
}

virtual int64_t storage_offset() const {
return storage_offset_;
}

// represents that numel() == 0.
inline bool is_empty() const {
for (int64_t i = 0; i < dim(); ++i) {
if (sizes()[i] == 0) {
return true;
}
}
return false;
}

virtual int64_t size(int64_t d) const;
virtual int64_t stride(int64_t d) const;

// TODO: get rid of this.
virtual at::StorageImpl* storageImpl() const { return storage_; }

protected:
TensorTypeId type_id_;
// INVARIANT: When storage is non-null, this scalar type must
// agree with the scalar type in storage
ScalarType scalar_type_;
bool is_variable_ = false;
bool is_wrapped_number_ = false;
public:
THTensor * tensor;

private:
TensorImpl(StorageImpl* storage, TensorTypeId type_id, ScalarType scalar_type, bool is_variable);
};
} // namespace at
21 changes: 17 additions & 4 deletions aten/src/ATen/UndefinedTensor.cpp
@@ -6,24 +6,37 @@ namespace at {

// should this use the globalContext? Can it get a context passed in somehow?
UndefinedTensor::UndefinedTensor()
: TensorImpl(UndefinedTensorId(), ScalarType::Undefined, nullptr, /* is variable */ false) {
: TensorImpl(UndefinedTensorId(), ScalarType::Undefined, /* is variable */ false) {
}

IntList UndefinedTensor::sizes() const {
AT_ERROR("sizes() called on undefined Tensor");
}

int64_t UndefinedTensor::size(int64_t d) const {
AT_ERROR("size(dim) called on an undefined Tensor");
}

int64_t UndefinedTensor::stride(int64_t d) const {
AT_ERROR("stride(dim) called on an undefined Tensor");
}

int64_t UndefinedTensor::dim() const {
AT_ERROR("dim() called on undefined Tensor");
}

void * UndefinedTensor::unsafeGetTH(bool retain) {
AT_ERROR("unsafeGetTH(bool retain) called on undefined Tensor");
}
std::unique_ptr<Storage> UndefinedTensor::storage() {
AT_ERROR("storage() called on undefined Tensor");
}

at::StorageImpl* UndefinedTensor::storageImpl() const {
AT_ERROR("storageImpl() called on an undefined Tensor");
}

int64_t UndefinedTensor::storage_offset() const {
AT_ERROR("storage_offset() called on an undefined Tensor");
}

IntList UndefinedTensor::strides() const {
AT_ERROR("strides() called on undefined Tensor");
}
5 changes: 4 additions & 1 deletion aten/src/ATen/UndefinedTensor.h
@@ -11,9 +11,12 @@ struct AT_API UndefinedTensor final : public TensorImpl {
}
IntList sizes() const override;
IntList strides() const override;
int64_t size(int64_t d) const override;
int64_t stride(int64_t d) const override;
int64_t dim() const override;
void * unsafeGetTH(bool retain) override;
std::unique_ptr<Storage> storage() override;
at::StorageImpl* storageImpl() const override;
int64_t storage_offset() const override;
private:
UndefinedTensor();
static UndefinedTensor _singleton;