Skip to content

Commit 33ebb58

Browse files
authored
Merge pull request #39 from iotamudelta/master
Merge from upstream
2 parents 3fadf87 + b29376c commit 33ebb58

File tree

90 files changed

+2019
-724
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

90 files changed

+2019
-724
lines changed

.jenkins/pytorch/test.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh"
99

1010
echo "Testing pytorch"
1111

12+
if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
13+
echo "Skipping ROCm tests for now"
14+
exit 0
15+
fi
16+
1217
# JIT C++ extensions require ninja.
1318
git clone https://github.com/ninja-build/ninja --quiet
1419
pushd ninja

README.md

Lines changed: 11 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<p align="center"><img width="40%" src="docs/source/_static/img/pytorch-logo-dark.png" /></p>
1+
![PyTorch Logo](https://github.com/pytorch/pytorch/blob/master/docs/source/_static/img/pytorch-logo-dark.png)
22

33
--------------------------------------------------------------------------------
44

@@ -34,32 +34,14 @@ See also the [ci.pytorch.org HUD](https://ezyang.github.io/pytorch-ci-hud/build/
3434

3535
At a granular level, PyTorch is a library that consists of the following components:
3636

37-
<table>
38-
<tr>
39-
<td><b> torch </b></td>
40-
<td> a Tensor library like NumPy, with strong GPU support </td>
41-
</tr>
42-
<tr>
43-
<td><b> torch.autograd </b></td>
44-
<td> a tape-based automatic differentiation library that supports all differentiable Tensor operations in torch </td>
45-
</tr>
46-
<tr>
47-
<td><b> torch.nn </b></td>
48-
<td> a neural networks library deeply integrated with autograd designed for maximum flexibility </td>
49-
</tr>
50-
<tr>
51-
<td><b> torch.multiprocessing </b></td>
52-
<td> Python multiprocessing, but with magical memory sharing of torch Tensors across processes. Useful for data loading and Hogwild training. </td>
53-
</tr>
54-
<tr>
55-
<td><b> torch.utils </b></td>
56-
<td> DataLoader, Trainer and other utility functions for convenience </td>
57-
</tr>
58-
<tr>
59-
<td><b> torch.legacy(.nn/.optim) </b></td>
60-
<td> legacy code that has been ported over from torch for backward compatibility reasons </td>
61-
</tr>
62-
</table>
37+
| Component | Description |
38+
| ---- | --- |
39+
| **torch** | a Tensor library like NumPy, with strong GPU support |
40+
| **torch.autograd** | a tape-based automatic differentiation library that supports all differentiable Tensor operations in torch |
41+
| **torch.nn** | a neural networks library deeply integrated with autograd designed for maximum flexibility |
42+
| **torch.multiprocessing** | Python multiprocessing, but with magical memory sharing of torch Tensors across processes. Useful for data loading and Hogwild training |
43+
| **torch.utils** | DataLoader, Trainer and other utility functions for convenience |
44+
| **torch.legacy(.nn/.optim)** | legacy code that has been ported over from torch for backward compatibility reasons |
6345

6446
Usually one uses PyTorch either as:
6547

@@ -72,7 +54,7 @@ Elaborating further:
7254

7355
If you use NumPy, then you have used Tensors (a.k.a ndarray).
7456

75-
<p align=center><img width="30%" src="docs/source/_static/img/tensor_illustration.png" /></p>
57+
![Tensor illustration](https://github.com/pytorch/pytorch/blob/master/docs/source/_static/img/tensor_illustration.png)
7658

7759
PyTorch provides Tensors that can live either on the CPU or the GPU, and accelerate
7860
compute by a huge amount.
@@ -99,7 +81,7 @@ from several research papers on this topic, as well as current and past work suc
9981
While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date.
10082
You get the best of speed and flexibility for your crazy research.
10183

102-
<p align=center><img width="80%" src="docs/source/_static/img/dynamic_graph.gif" /></p>
84+
![Dynamic graph](https://github.com/pytorch/pytorch/blob/master/docs/source/_static/img/dynamic_graph.gif)
10385

10486
### Python First
10587

aten/src/ATen/Retainable.h

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
// aten/src/ATen/Retainable.h — reconstructed from the diff hunk
// (per-line diff markers and superseded old lines stripped).
namespace at {

// Base class for refcounted things; allows for collections of generic
// refcounted objects that include tensors.
//
// INVARIANT: once refcount reaches 0 it can never go up again.
// INVARIANT: weak_refcount == number of weak references + (refcount > 0 ? 1 : 0)
struct Retainable {
  // A freshly constructed object carries one strong reference, plus the
  // single weak reference held collectively by all strong references.
  Retainable(): refcount(1), weak_refcount(1) {}

  // Add one strong reference.
  void retain() {
    ++refcount;
  }

  // Drop one strong reference. When the last strong reference dies the
  // object's resources are released immediately; the allocation itself
  // survives until the last weak reference is also gone.
  void release() {
    if(--refcount == 0) {
      // If we know that this is the last reference then we can skip
      // all the decrements and release_resources().
      if (weak_refcount == 1) {
        delete this;
      } else {
        release_resources();
        weak_release();  // drop the weak ref owned by the strong refs
      }
    }
  }

  // Add one weak reference.
  void weak_retain() {
    ++weak_refcount;
  }

  // Drop one weak reference; frees the allocation when it was the last one.
  void weak_release() {
    if (--weak_refcount == 0) {
      delete this;
    }
  }

  // Try to promote a weak reference into a strong one. Returns false when
  // the object is already dead (refcount reached 0). The CAS loop retries
  // if refcount changed between the load and the increment, and never
  // resurrects a refcount that has hit 0.
  bool weak_lock() {
    for (;;) {
      auto current_refcount = refcount.load();
      if (current_refcount == 0) return false;
      if (refcount.compare_exchange_strong(current_refcount, current_refcount + 1)) break;
    }
    return true;
  }

  // Current number of strong references.
  uint32_t use_count() const {
    return refcount.load();
  }

  // Current number of weak references (including the one held on behalf of
  // the strong references while refcount > 0).
  uint32_t weak_use_count() const {
    return weak_refcount.load();
  }

  // Hook for subclasses to free owned resources as soon as the last strong
  // reference dies, before the allocation itself is deleted.
  virtual void release_resources() {};
  virtual ~Retainable() {}
private:
  // INVARIANT: once refcount reaches 0 it can never go up
  // INVARIANT: weak_refcount = number of weak references + (refcount > 0 ? 1 : 0)
  std::atomic<uint32_t> refcount;
  std::atomic<uint32_t> weak_refcount;
};

}

aten/src/ATen/TensorBase.h

Lines changed: 55 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,54 +5,62 @@
55

66
namespace at { namespace detail {
77

8-
// TensorBase is the base class for Tensor which handles the reference counting
9-
struct TensorBase {
10-
TensorBase(): TensorBase(UndefinedTensor::singleton(), false) {}
11-
TensorBase(TensorImpl * self, bool retain)
8+
// TensorBaseImpl is the base class for Tensor which handles the reference counting
9+
template<bool is_strong>
10+
struct TensorBaseImpl {
11+
TensorBaseImpl(): TensorBaseImpl(UndefinedTensor::singleton(), false) {}
12+
TensorBaseImpl(TensorImpl * self, bool should_retain)
1213
: pImpl(self) {
1314
if (pImpl == nullptr) {
14-
throw std::runtime_error("TensorBase with nullptr not supported");
15+
throw std::runtime_error("TensorBaseImpl with nullptr not supported");
16+
}
17+
if(should_retain && pImpl != UndefinedTensor::singleton()) {
18+
retain();
1519
}
16-
if(retain && pImpl != UndefinedTensor::singleton())
17-
pImpl->retain();
1820
}
19-
TensorBase(const TensorBase & rhs)
21+
TensorBaseImpl(const TensorBaseImpl & rhs)
2022
: pImpl(rhs.pImpl) {
21-
if (pImpl != UndefinedTensor::singleton())
22-
pImpl->retain();
23+
if (pImpl != UndefinedTensor::singleton()) {
24+
retain();
25+
}
2326
}
24-
TensorBase(TensorBase && rhs) noexcept
27+
TensorBaseImpl(TensorBaseImpl && rhs) noexcept
2528
: pImpl(rhs.pImpl) {
2629
rhs.pImpl = UndefinedTensor::singleton();
2730
}
28-
~TensorBase() {
29-
if (pImpl != UndefinedTensor::singleton())
30-
pImpl->release();
31+
~TensorBaseImpl() {
32+
if (pImpl != UndefinedTensor::singleton()) {
33+
release();
34+
}
3135
}
32-
TensorBase & operator=(TensorBase && rhs) & {
36+
TensorBaseImpl & operator=(TensorBaseImpl && rhs) & {
3337
rhs.swap(*this);
3438
return *this;
3539
}
36-
TensorBase & operator=(TensorBase const & rhs) & {
37-
//TensorBase ctor retains original rhs.pImpl
38-
//then rhs.pImpl is swapped with this->pImpl
39-
//finally TensorBase dtor releases rhs.pImpl, which was originally this->pImpl
40-
TensorBase(rhs).swap(*this);
41-
return *this;
40+
TensorBaseImpl & operator=(TensorBaseImpl const & rhs) & {
41+
//TensorBaseImpl ctor retains original rhs.pImpl
42+
//then rhs.pImpl is swapped with this->pImpl
43+
//finally TensorBaseImpl dtor releases rhs.pImpl, which was originally this->pImpl
44+
TensorBaseImpl(rhs).swap(*this);
45+
return *this;
4246
}
4347
int64_t dim() const {
44-
return pImpl->dim();
48+
if (is_strong) {
49+
return pImpl->dim();
50+
} else {
51+
AT_ERROR("Can't call dim() on a WeakTensor");
52+
}
4553
}
4654
void reset() {
47-
TensorBase().swap(*this);
55+
TensorBaseImpl().swap(*this);
4856
}
4957
void reset(TensorImpl * rhs) {
50-
TensorBase(rhs, true).swap(*this);
58+
TensorBaseImpl(rhs, true).swap(*this);
5159
}
52-
void reset(TensorImpl * rhs, bool retain) {
53-
TensorBase(rhs, retain).swap(*this );
60+
void reset(TensorImpl * rhs, bool should_retain) {
61+
TensorBaseImpl(rhs, should_retain).swap(*this );
5462
}
55-
void swap(TensorBase & rhs) {
63+
void swap(TensorBaseImpl & rhs) {
5664
TensorImpl * tmp = pImpl;
5765
pImpl = rhs.pImpl;
5866
rhs.pImpl = tmp;
@@ -75,6 +83,26 @@ struct TensorBase {
7583
//TODO(zach): sort out friend structes
7684
public:
7785
TensorImpl * pImpl;
86+
87+
private:
88+
void retain() {
89+
if (is_strong) {
90+
pImpl->retain();
91+
} else {
92+
pImpl->weak_retain();
93+
}
94+
}
95+
96+
void release() {
97+
if (is_strong) {
98+
pImpl->release();
99+
} else {
100+
pImpl->weak_release();
101+
}
102+
}
78103
};
79104

105+
using TensorBase = TensorBaseImpl<true>;
106+
using WeakTensorBase = TensorBaseImpl<false>;
107+
80108
}} // namespace at::detail

aten/src/ATen/Utils.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,9 @@
1212

1313
// aten/src/ATen/Utils.h — reconstructed from the diff hunk: the commit
// removes __ubsan_ignore_function__ from both branches; diff markers and
// the deleted lines are stripped here.
// These expand to clang no_sanitize attributes, and to nothing on other
// compilers so annotated declarations still parse.
#if defined(__clang__)
#define __ubsan_ignore_float_divide_by_zero__ __attribute__((no_sanitize("float-divide-by-zero")))
#define __ubsan_ignore_vptr__ __attribute__((no_sanitize("vptr")))
#else
#define __ubsan_ignore_float_divide_by_zero__
#define __ubsan_ignore_vptr__
#endif
2220

aten/src/ATen/templates/Tensor.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "ATen/Utils.h"
1414
#include "ATen/Device.h"
1515
#include "ATen/Layout.h"
16+
#include "ATen/optional.h"
1617

1718
namespace at {
1819
struct Type;
@@ -42,6 +43,7 @@ namespace at {
4243
// Note that Tensor can also be NULL, i.e. it is not associated with any underlying TensorImpl, and
4344
// special care must be taken to handle this.
4445
struct Tensor : public detail::TensorBase {
46+
using TensorBase = detail::TensorBase;
4547
Tensor() : TensorBase() {}
4648
Tensor(TensorImpl * self, bool retain) : TensorBase(self, retain) {}
4749
Tensor(const TensorBase & rhs) : TensorBase(rhs) {}
@@ -198,6 +200,46 @@ struct Tensor : public detail::TensorBase {
198200
auto m(F func, Args&&... params) const -> decltype(func(*this, std::forward<Args>(params)...)) {
199201
return func(*this, std::forward<Args>(params)...);
200202
}
203+
204+
friend struct WeakTensor;
205+
};
206+
207+
struct WeakTensor : public detail::WeakTensorBase {
208+
using WeakTensorBase = detail::WeakTensorBase;
209+
WeakTensor() : WeakTensorBase() {}
210+
WeakTensor(TensorImpl * self, bool retain) : WeakTensorBase(self, retain) {}
211+
WeakTensor(const WeakTensor & rhs) = default;
212+
WeakTensor(WeakTensor && rhs) noexcept = default;
213+
WeakTensor(const Tensor& t) : WeakTensorBase(t.pImpl, true) {}
214+
215+
// reimplemented from TensorBase so the return type is WeakTensor rather than TensorBase
216+
WeakTensor & operator=(WeakTensor && rhs) & {
217+
rhs.swap(*this);
218+
return *this;
219+
}
220+
WeakTensor & operator=(WeakTensor const & rhs) & {
221+
//Tensor ctor retains original rhs.pImpl
222+
//then rhs.pImpl is swapped with this->pImpl
223+
//finally Tensor dtor releases rhs.pImpl, which was originally this->pImpl
224+
WeakTensor(rhs).swap(*this);
225+
return *this;
226+
}
227+
228+
WeakTensor & operator=(const Tensor& t) {
229+
WeakTensor(t.pImpl, true).swap(*this);
230+
return *this;
231+
}
232+
233+
// non-retaining
234+
TensorImpl * unsafeGetTensorImpl() const {
235+
return pImpl;
236+
}
237+
238+
// XXX: this can return undefined tensors
239+
// Ideally it would be at::optional<Tensor>, but MSVC is too cool for that
240+
Tensor lock() const {
241+
return pImpl->weak_lock() ? Tensor(pImpl, false) : Tensor();
242+
}
201243
};
202244

203245
namespace detail {

aten/src/ATen/templates/TensorDerived.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ void * ${Tensor}::unsafeGetTH(bool retain) {
4949
return tensor;
5050
}
5151

52+
void ${Tensor}::release_resources() {
53+
${THTensor}_free(${state,} tensor);
54+
tensor = nullptr;
55+
}
56+
5257
${TensorDenseOrSparse}
5358

5459
}

aten/src/ATen/templates/TensorDerived.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ struct ${Tensor} final : public TensorImpl {
2323
virtual Scalar localScalar() override;
2424
virtual void * unsafeGetTH(bool retain) override;
2525
virtual std::unique_ptr<Storage> storage() override;
26+
virtual void release_resources() override;
2627
static const char * typeString();
2728

2829
//TODO(zach): sort of friend permissions later so this

aten/src/ATen/test/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ list(APPEND ATen_CPU_TEST_SRCS
1818
${CMAKE_CURRENT_SOURCE_DIR}/test_parallel.cpp
1919
${CMAKE_CURRENT_SOURCE_DIR}/undefined_tensor_test.cpp
2020
${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp
21-
${CMAKE_CURRENT_SOURCE_DIR}/tbb_init_test.cpp)
21+
${CMAKE_CURRENT_SOURCE_DIR}/tbb_init_test.cpp
22+
${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp)
2223

2324
list(APPEND ATen_CUDA_TEST_SRCS
2425
${CMAKE_CURRENT_SOURCE_DIR}/integer_divider_test.cu

0 commit comments

Comments
 (0)