Skip to content

Commit 460b96f

Browse files
authored
Merge pull request #97 from iotamudelta/master
Merge from upstream.
2 parents 3a924aa + 943242c commit 460b96f

File tree

92 files changed

+5694
-895
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

92 files changed

+5694
-895
lines changed

.jenkins/pytorch/test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ popd
2020
# if you're not careful. Check this if you made some changes and the
2121
# ASAN test is not working
2222
if [[ "$BUILD_ENVIRONMENT" == *asan* ]]; then
23-
export ASAN_OPTIONS=detect_leaks=0:symbolize=1
23+
export ASAN_OPTIONS=detect_leaks=0:symbolize=1:strict_init_order=true
2424
# We suppress the vptr volation, since we have separate copies of
2525
# libprotobuf in both libtorch.so and libcaffe2.so, and it causes
2626
# the following problem:

aten/src/ATen/Storage.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ struct AT_API Storage {
2121
Storage(const Storage&) = delete;
2222
Storage(Storage&&) = delete;
2323
Storage(const Storage&&) = delete;
24+
void set_pImpl(StorageImpl* storage_impl) {
25+
storage_impl_ = storage_impl;
26+
}
2427
StorageImpl* pImpl() {
2528
return storage_impl_;
2629
}

aten/src/ATen/StorageImpl.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ StorageImpl::StorageImpl(
99
at::DataPtr data_ptr,
1010
at::Allocator* allocator,
1111
bool resizable)
12-
: scalar_type(scalar_type),
13-
data_ptr(std::move(data_ptr)),
14-
size(size),
15-
resizable(resizable),
16-
allocator(allocator),
17-
finalizer(nullptr) {}
12+
: scalar_type_(scalar_type),
13+
data_ptr_(std::move(data_ptr)),
14+
size_(size),
15+
resizable_(resizable),
16+
allocator_(allocator),
17+
finalizer_(nullptr) {}
1818

1919
StorageImpl::StorageImpl(
2020
at::ScalarType scalar_type,
@@ -30,7 +30,7 @@ StorageImpl::StorageImpl(
3030

3131
namespace detail {
3232
Backend get_backend(StorageImpl* storage_impl) {
33-
if (storage_impl->data_ptr.device().is_cuda()) {
33+
if (storage_impl->data_ptr().device().is_cuda()) {
3434
return Backend::CUDA;
3535
}
3636
return Backend::CPU;

aten/src/ATen/StorageImpl.h

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -41,31 +41,27 @@ namespace at {
4141
struct Type;
4242

4343
struct AT_API StorageImpl : public Retainable {
44-
44+
public:
4545
StorageImpl() = delete;
4646
virtual ~StorageImpl() {};
4747
StorageImpl(at::ScalarType, ptrdiff_t, at::DataPtr, at::Allocator*, bool);
4848
StorageImpl(at::ScalarType, ptrdiff_t, at::Allocator*, bool);
49-
at::ScalarType scalar_type;
50-
at::DataPtr data_ptr;
51-
ptrdiff_t size;
52-
bool resizable;
53-
at::Allocator* allocator;
54-
std::unique_ptr<THFinalizer> finalizer;
5549
StorageImpl(StorageImpl&) = delete;
5650
StorageImpl(const StorageImpl&) = delete;
57-
StorageImpl(StorageImpl&&) = delete;
51+
// NB: Don't move ref count!
52+
StorageImpl(StorageImpl&& other) = delete;
5853
StorageImpl(const StorageImpl&&) = delete;
54+
StorageImpl& operator=(StorageImpl&& other) = delete;
5955

6056
// TODO: Rename this into th_data, and move it out of the class;
6157
// the real data shouldn't call th::from_type
6258
template <typename T>
6359
inline T* data() const {
6460
auto scalar_type_T = at::CTypeToScalarType<th::from_type<T>>::to();
65-
if (scalar_type != scalar_type_T) {
61+
if (scalar_type_ != scalar_type_T) {
6662
AT_ERROR(
6763
"Attempt to access StorageImpl having data type ",
68-
at::toString(scalar_type),
64+
at::toString(scalar_type_),
6965
" as data type ",
7066
at::toString(scalar_type_T));
7167
}
@@ -74,40 +70,72 @@ struct AT_API StorageImpl : public Retainable {
7470

7571
template <typename T>
7672
inline T* unsafe_data() const {
77-
return static_cast<T*>(this->data_ptr.get());
73+
return static_cast<T*>(this->data_ptr_.get());
7874
}
7975

8076
void release_resources() {
81-
if (finalizer) {
82-
(*finalizer)();
77+
if (finalizer_) {
78+
(*finalizer_)();
8379
}
84-
finalizer = nullptr;
85-
data_ptr.clear();
80+
finalizer_ = nullptr;
81+
data_ptr_.clear();
8682
}
8783

8884
void operator=(const StorageImpl&) = delete;
8985

9086
virtual size_t elementSize() const {
91-
return at::elementSize(scalar_type);
87+
return at::elementSize(scalar_type_);
9288
}
9389

94-
//TODO: Rename to size() and size to size_
95-
size_t get_size() const {
96-
return size;
90+
Type& type();
91+
92+
// TODO: Rename to size() and size to size_
93+
ptrdiff_t size() const {
94+
return size_;
95+
};
96+
void set_size(ptrdiff_t size) {
97+
size_ = size;
98+
};
99+
bool resizable() const {
100+
return resizable_;
101+
};
102+
at::DataPtr& data_ptr() {
103+
return data_ptr_;
104+
};
105+
void set_data_ptr(at::DataPtr&& data_ptr) {
106+
data_ptr_ = std::move(data_ptr);
97107
};
98108
void* data() {
99-
return data_ptr.get();
109+
return data_ptr_.get();
100110
};
101111
const void* data() const {
102-
return data_ptr.get();
112+
return data_ptr_.get();
113+
};
114+
at::Allocator* allocator() {
115+
return allocator_;
116+
};
117+
at::ScalarType& scalar_type() {
118+
return scalar_type_;
119+
};
120+
const at::Allocator* allocator() const {
121+
return allocator_;
103122
};
104-
105123
int getDevice() const {
106-
return data_ptr.device().index();
124+
return data_ptr_.device().index();
107125
}
108-
void set_resizable(bool resizable_) {
109-
resizable = resizable_;
126+
void set_resizable(bool resizable) {
127+
resizable_ = resizable;
110128
}
129+
130+
private:
131+
at::ScalarType scalar_type_;
132+
at::DataPtr data_ptr_;
133+
ptrdiff_t size_;
134+
bool resizable_;
135+
136+
public:
137+
at::Allocator* allocator_;
138+
std::unique_ptr<THFinalizer> finalizer_;
111139
};
112140

113141
namespace detail {

aten/src/ATen/THLongStorageView.h

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <ATen/StorageImpl.h>
34
#include "TH/TH.h"
45
#include "TH/THStorageFunctions.hpp"
56
#include "TH/THTypeConversion.hpp"
@@ -16,11 +17,11 @@ enum class THLongStorageViewKind {
1617
// used as an argument where THSize and THStride are passed into TH
1718
class THLongStorageView {
1819
public:
19-
operator THLongStorage*() {
20-
if (storage.size == 0 && zero_dim_to_null) {
20+
operator StorageImpl*() {
21+
if (storage.pImpl()->size() == 0 && zero_dim_to_null) {
2122
return nullptr;
2223
}
23-
return &storage;
24+
return storage.pImpl();
2425
}
2526

2627
/*
@@ -37,8 +38,7 @@ class THLongStorageView {
3738
*/
3839

3940
THLongStorageView(ArrayRef<int64_t> ref, THLongStorageViewKind kind)
40-
: storage(at::CTypeToScalarType<th::from_type<int64_t>>::to(), 0, getTHDefaultAllocator(), 0), zero_dim_to_null(false)
41-
{
41+
: storage(nullptr), zero_dim_to_null(false) {
4242
// zero_dim_to_one converts an empty ArrayRef into [1]
4343
// zero_dim_to_null converts an empty ArrayRef into a null THLongStorage
4444
bool zero_dim_to_one = false;
@@ -53,22 +53,33 @@ class THLongStorageView {
5353
break;
5454
}
5555

56-
if(zero_dim_to_one && ref.size() == 0) {
56+
if (zero_dim_to_one && ref.size() == 0) {
5757
// make storage of size 0 actually a 1-length storage with 1 element
5858
// so that our 0-dim tensors get allocated as 1-dim inside TH
59+
5960
one = 1;
60-
storage.data_ptr = {&one, kCPU}; // non-owning
61-
storage.size = 1;
61+
storage.set_pImpl(new StorageImpl(
62+
at::CTypeToScalarType<th::from_type<int64_t>>::to(),
63+
1,
64+
{&one, kCPU}, // non-owning
65+
nullptr,
66+
false));
6267
} else {
63-
storage.data_ptr = {const_cast<void*>(static_cast<const void*>(ref.data())), kCPU}; // non-owning
64-
storage.size = ref.size();
68+
storage.set_pImpl(new StorageImpl(
69+
at::CTypeToScalarType<th::from_type<int64_t>>::to(),
70+
ref.size(),
71+
{const_cast<void*>(static_cast<const void*>(ref.data())),
72+
kCPU}, // non-owning
73+
nullptr,
74+
false));
6575
}
66-
storage.scalar_type = at::CTypeToScalarType<th::from_type<int64_t>>::to();
67-
storage.set_resizable(false);
6876
}
6977
private:
7078
int64_t one;
71-
THLongStorage storage;
79+
// NB: The lifetime of objects like one are tied to the lifetime of an
80+
// instance of this class. That means if storage is used after an instance of
81+
// this class dies, it'll be corrupted.
82+
Storage storage;
7283
bool zero_dim_to_null;
7384
};
7485

aten/src/ATen/Utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ static inline T* checked_cast_storage(Base* expr, const char * name, int pos, Ba
2929
AT_ERROR("Expected object of backend ", backend, " but got backend ", at::detail::get_backend(expr->pImpl()),
3030
" for argument #", pos, " '", name, "'");
3131
}
32-
if (expr->pImpl()->scalar_type != scalar_type) {
33-
AT_ERROR("Expected object of scalar type ", scalar_type, " but got scalar type ", expr->pImpl()->scalar_type,
32+
if (expr->pImpl()->scalar_type() != scalar_type) {
33+
AT_ERROR("Expected object of scalar type ", scalar_type, " but got scalar type ", expr->pImpl()->scalar_type(),
3434
" for argument #", pos, " '", name, "'");
3535
}
3636
// NB: We're getting rid of derived types soon!

aten/src/ATen/core/C++17.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ template<class T> using remove_reference_t = std::remove_reference_t<T>;
8181
template<class T> using remove_cv_t = std::remove_cv_t<T>;
8282
template<class T> using result_of_t = std::result_of_t<T>;
8383
template<class T> using decay_t = std::decay_t<T>;
84+
template<class T> using remove_const_t = std::remove_const_t<T>;
8485
#else
8586
template<bool B, class T, class F> using conditional_t = typename std::conditional<B, T, F>::type;
8687
template<bool B, class T = void> using enable_if_t = typename std::enable_if<B, T>::type;
@@ -89,6 +90,7 @@ template<class T> using remove_reference_t = typename std::remove_reference<T>::
8990
template<class T> using remove_cv_t = typename std::remove_cv<T>::type;
9091
template<class T> using result_of_t = typename std::result_of<T>::type;
9192
template<class T> using decay_t = typename std::decay<T>::type;
93+
template<class T> using remove_const_t = typename std::remove_const<T>::type;
9294
#endif
9395

9496

aten/src/ATen/core/intrusive_ptr.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
#include <ATen/core/intrusive_ptr.h>

0 commit comments

Comments
 (0)