Skip to content

Commit bbb0131

Browse files
ezyang authored and facebook-github-bot committed
Add weak pointer and finalizer support directly to THStorage. (#9148)
Summary: The underlying use-case is the file descriptor to storage cache in torch.multiprocessing.reductions. Previously, this was implemented by wrapping an existing allocator with a "weak ref" allocator which also knew to null out the weak reference when the storage died. This is terribly oblique, and prevents us from refactoring the allocators to get rid of per-storage allocator state.

So rather than going through this fiasco, we directly implement weak pointers and finalizers in THStorage. Weak pointers to THStorage retain the THStorage struct, but not the data_ptr. When all strong references die, data_ptr dies and the finalizers get invoked.

There is one major hazard in this patch, which is what happens if you repeatedly call _weak_ref on a storage. For cleanliness, we no longer shove our grubby fingers into the finalizer struct to see if there is already a Python object for the weak reference and return it; we just create a new one (no one is checking these Python objects for identity). This means if you keep calling it, we'll keep piling on finalizers. That's bad! But I am not going to fix it until it is actually a problem for someone, because then we need to add another caching layer.

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch/pytorch#9148
Differential Revision: D8729106
Pulled By: ezyang
fbshipit-source-id: 69710ca3b7c7e05069090e1b263f8b6b9f1cf72f
1 parent d908161 commit bbb0131

File tree

8 files changed

+93
-51
lines changed

8 files changed

+93
-51
lines changed

aten/src/TH/THStorage.cpp

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,56 @@
1414
#include "generic/THStorageCopy.cpp"
1515
#include "THGenerateHalfType.h"
1616

17+
// Free a non-weak pointer to THStorage
1718
void THStorage_free(THStorage *storage) {
1819
AT_ASSERT(storage->backend == at::kCPU);
1920

2021
if (!storage) {
2122
return;
2223
}
2324

24-
if ((storage->flag & TH_STORAGE_REFCOUNTED) && (storage->refcount.load() > 0)) {
25+
if (storage->flag & TH_STORAGE_REFCOUNTED) {
2526
if (--storage->refcount == 0) {
27+
if (storage->finalizer) {
28+
(*storage->finalizer)();
29+
}
30+
storage->finalizer.~unique_ptr<THFinalizer>();
2631
if (storage->flag & TH_STORAGE_FREEMEM) {
2732
static_cast<THAllocator*>(storage->allocatorVoidPtr)->free(storage->allocatorContext, storage->data_ptr);
2833
}
2934
if (storage->flag & TH_STORAGE_VIEW) {
3035
THStorage_free(storage->view);
3136
}
32-
storage->refcount.~atomic<int>();
33-
THFree(storage);
37+
THStorage_weakFree(storage);
3438
}
3539
}
3640
}
3741

42+
// Manually retains a weak reference
43+
void THStorage_weakRetain(THStorage *weak_storage) {
44+
weak_storage->weakcount++;
45+
}
46+
47+
// Releases a weak reference
48+
void THStorage_weakFree(THStorage *weak_storage) {
49+
if (--weak_storage->weakcount == 0) {
50+
weak_storage->refcount.~atomic<int>();
51+
weak_storage->weakcount.~atomic<int>();
52+
THFree(weak_storage);
53+
}
54+
}
55+
56+
// Given a weak reference, returns a strong reference to a storage (which must
57+
// be freed when done) or null if the storage is already dead.
58+
THStorage* THStorage_weakLock(THStorage *weak_storage) {
59+
for (;;) {
60+
int refcount = weak_storage->refcount.load();
61+
if (refcount == 0) return nullptr;
62+
if (weak_storage->refcount.compare_exchange_strong(refcount, refcount + 1)) break;
63+
}
64+
return weak_storage;
65+
}
66+
3867
// Produces a THDescBuff for `size` by forwarding its data pointer and its
// element count to _THSizeDesc.
THDescBuff THLongStorage_sizeDesc(const THLongStorage *size) {
  auto *dims = THLongStorage_data(size);
  const auto ndims = size->size;
  return _THSizeDesc(dims, ndims);
}
@@ -89,6 +118,8 @@ THStorage* THStorage_newWithAllocator(at::ScalarType scalar_type, ptrdiff_t size
89118
storage->data_ptr = allocator->malloc(allocatorContext, at::elementSize(scalar_type)*size);
90119
storage->size = size;
91120
new (&storage->refcount) std::atomic<int>(1);
121+
new (&storage->weakcount) std::atomic<int>(1); // from the strong reference
122+
new (&storage->finalizer) std::unique_ptr<THFinalizer>(nullptr);
92123
storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM;
93124
storage->allocatorVoidPtr = allocator;
94125
storage->allocatorContext = allocatorContext;
@@ -140,19 +171,6 @@ void THStorage_retain(THStorage *storage)
140171
}
141172
}
142173

143-
int THStorage_retainIfLive(THStorage *storage)
144-
{
145-
// TODO: Check if TH_STORAGE_REFCOUNTED?
146-
int refcount = storage->refcount.load();
147-
while (refcount > 0) {
148-
if (storage->refcount.compare_exchange_strong(refcount, refcount + 1)) {
149-
return 1;
150-
}
151-
refcount = storage->refcount.load();
152-
}
153-
return 0;
154-
}
155-
156174
THStorage* THStorage_newWithData(at::ScalarType scalar_type, void *data, ptrdiff_t size)
157175
{
158176
return THStorage_newWithDataAndAllocator(scalar_type, data, size,
@@ -168,7 +186,9 @@ THStorage* THStorage_newWithDataAndAllocator(at::ScalarType scalar_type,
168186
storage->scalar_type = scalar_type;
169187
storage->data_ptr = data;
170188
storage->size = size;
171-
storage->refcount = 1;
189+
new (&storage->refcount) std::atomic<int>(1);
190+
new (&storage->weakcount) std::atomic<int>(1); // from the strong reference
191+
new (&storage->finalizer) std::unique_ptr<THFinalizer>(nullptr);
172192
storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM;
173193
storage->allocatorVoidPtr = allocator;
174194
storage->allocatorContext = allocatorContext;
@@ -227,6 +247,7 @@ void THStorage_swap(THStorage *storage1, THStorage *storage2)
227247
// don't swap refcount!
228248
SWAP(allocatorVoidPtr);
229249
SWAP(allocatorContext);
250+
SWAP(finalizer);
230251
SWAP(view);
231252
SWAP(device);
232253
#undef SWAP

aten/src/TH/THStorage.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
// This exists to have a data-type independent way of freeing (necessary for THPPointer).
2121
TH_API void THStorage_free(THStorage *storage);
22+
TH_API void THStorage_weakFree(THStorage *storage);
2223

2324
TH_API THDescBuff THLongStorage_sizeDesc(const THLongStorage *size);
2425
TH_API THLongStorage *THLongStorage_newInferSize(THLongStorage *size, ptrdiff_t nElement);

aten/src/TH/THStorage.hpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,45 @@
1010
#include "THTypeConversion.hpp"
1111
#include <atomic>
1212

13+
// Note [Weak references for intrusive refcounting]
14+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
15+
// Here's the scheme:
16+
//
17+
// - refcount == number of strong references to the object
18+
// weakcount == number of weak references to the object,
19+
// plus one more if refcount > 0
20+
//
21+
// - THStorage stays live as long as there are any strong
22+
// or weak pointers to it (weakcount > 0, since strong
23+
// references count as a +1 to weakcount)
24+
//
25+
// - finalizers are called and data_ptr is deallocated when refcount == 0
26+
//
27+
// - Once refcount == 0, it can never again be > 0 (the transition
28+
// from > 0 to == 0 is monotonic)
29+
//
30+
// - When you access THStorage via a weak pointer, you must
31+
//     atomically increment refcount, if it is greater than 0.
32+
// If it is not, you must report that the storage is dead.
33+
//
34+
35+
// Abstract callback interface for storage finalizers. Implementations
// override operator() with the action to perform.
struct THFinalizer {
  // Performs the finalization action.
  virtual void operator()() = 0;
  // Virtual so that destroying through a THFinalizer pointer (e.g. a
  // std::unique_ptr<THFinalizer>) runs the derived class's destructor.
  virtual ~THFinalizer() = default;
};
39+
1340
typedef struct THStorage
1441
{
1542
at::Backend backend; // kCPU or kCUDA only
1643
at::ScalarType scalar_type;
1744
void *data_ptr;
1845
ptrdiff_t size;
1946
std::atomic<int> refcount;
47+
std::atomic<int> weakcount;
2048
char flag;
2149
void *allocatorVoidPtr; // Either THDeviceAllocator or THCDeviceAllocator
2250
void *allocatorContext;
51+
std::unique_ptr<THFinalizer> finalizer;
2352
struct THStorage *view;
2453
int device;
2554

@@ -51,11 +80,14 @@ THStorage* THStorage_newWithMapping(at::ScalarType scalar_type, const char *file
5180
void THStorage_setFlag(THStorage *storage, const char flag);
5281
void THStorage_clearFlag(THStorage *storage, const char flag);
5382
void THStorage_retain(THStorage *storage);
54-
int THStorage_retainIfLive(THStorage *storage);
5583
THStorage* THStorage_newWithData(at::ScalarType scalar_type, void *data, ptrdiff_t size);
5684
THStorage* THStorage_newWithDataAndAllocator(at::ScalarType scalar_type,
5785
void* data, ptrdiff_t size,
5886
THAllocator* allocator,
5987
void* allocatorContext);
6088
void THStorage_resize(THStorage *storage, ptrdiff_t size);
6189
void THStorage_swap(THStorage *storage1, THStorage *storage2);
90+
91+
void THStorage_weakRetain(THStorage *weak_storage);
92+
void THStorage_weakFree(THStorage *weak_storage);
93+
THStorage* THStorage_weakLock(THStorage *weak_storage);

aten/src/TH/generic/THStorage.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,6 @@ void THStorage_(retain)(THStorage *storage)
9595
THStorage_retain(storage);
9696
}
9797

98-
int THStorage_(retainIfLive)(THStorage *storage)
99-
{
100-
return THStorage_retainIfLive(storage);
101-
}
102-
10398
void THStorage_(free)(THStorage *storage)
10499
{
105100
THStorage_free(storage);

aten/src/TH/generic/THStorage.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,6 @@ TH_API void THStorage_(clearFlag)(THStorage *storage, const char flag);
6666
TH_API void THStorage_(retain)(THStorage *storage);
6767
TH_API void THStorage_(swap)(THStorage *storage1, THStorage *storage2);
6868

69-
/* used by StorageSharing */
70-
TH_API int THStorage_(retainIfLive)(THStorage *storage);
71-
7269
/* might differ with other API (like CUDA) */
7370
TH_API void THStorage_(free)(THStorage *storage);
7471
TH_API void THStorage_(resize)(THStorage *storage, ptrdiff_t size);

aten/src/THC/THCStorage.cpp

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ THCStorage* THCStorage_newWithAllocator(THCState *state,
3434
THCStorage *storage = (THCStorage*)THAlloc(sizeof(THCStorage));
3535
memset(storage, 0, sizeof(THCStorage));
3636
new (&storage->refcount) std::atomic<int>(1);
37+
new (&storage->weakcount) std::atomic<int>(1);
38+
new (&storage->finalizer) std::unique_ptr<THFinalizer>(nullptr);
3739
storage->backend = at::kCUDA;
3840
storage->scalar_type = scalar_type;
3941
storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM;
@@ -60,25 +62,25 @@ THCStorage* THCStorage_newWithAllocator(THCState *state,
6062
return storage;
6163
}
6264

63-
void THCStorage_free(THCState *state, THCStorage *self)
65+
void THCStorage_free(THCState *state, THCStorage *storage)
6466
{
65-
AT_ASSERT(self->backend == at::kCUDA);
66-
67-
if(!(self->flag & TH_STORAGE_REFCOUNTED))
68-
return;
67+
AT_ASSERT(storage->backend == at::kCUDA);
6968

70-
if (--self->refcount == 0)
71-
{
72-
if(self->flag & TH_STORAGE_FREEMEM) {
73-
auto* thc_device_allocator = static_cast<THCDeviceAllocator*>(self->allocatorVoidPtr);
74-
THCudaCheck(
75-
(*thc_device_allocator->free)(self->allocatorContext, self->data_ptr));
76-
}
77-
if(self->flag & TH_STORAGE_VIEW) {
78-
THCStorage_free(state, self->view);
69+
if (storage->flag & TH_STORAGE_REFCOUNTED) {
70+
if (--storage->refcount == 0) {
71+
if (storage->finalizer) {
72+
(*storage->finalizer)();
73+
}
74+
storage->finalizer.~unique_ptr<THFinalizer>();
75+
if (storage->flag & TH_STORAGE_FREEMEM) {
76+
auto* thc_device_allocator = static_cast<THCDeviceAllocator*>(storage->allocatorVoidPtr);
77+
THCudaCheck((*thc_device_allocator->free)(storage->allocatorContext, storage->data_ptr));
78+
}
79+
if (storage->flag & TH_STORAGE_VIEW) {
80+
THCStorage_free(state, storage->view);
81+
}
82+
THStorage_weakFree(storage);
7983
}
80-
self->refcount.~atomic<int>();
81-
THFree(self);
8284
}
8385
}
8486

@@ -174,7 +176,9 @@ THCStorage* THCStorage_newWithDataAndAllocator(
174176
storage->scalar_type = scalar_type;
175177
storage->data_ptr = data;
176178
storage->size = size;
177-
storage->refcount = 1;
179+
new (&storage->refcount) std::atomic<int>(1);
180+
new (&storage->weakcount) std::atomic<int>(1);
181+
new (&storage->finalizer) std::unique_ptr<THFinalizer>(nullptr);
178182
storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM;
179183
storage->allocatorVoidPtr = allocator;
180184
storage->allocatorContext = allocatorContext;

aten/src/THC/generic/THCStorage.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,6 @@ void THCStorage_(retain)(THCState *state, THCStorage *self)
122122
THStorage_retain(self);
123123
}
124124

125-
int THCStorage_(retainIfLive)(THCState *state, THCStorage *storage)
126-
{
127-
return THStorage_retainIfLive(storage);
128-
}
129-
130125
void THCStorage_(free)(THCState *state, THCStorage *self)
131126
{
132127
THCStorage_free(state, self);

aten/src/THC/generic/THCStorage.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,6 @@ THC_API void THCStorage_(setFlag)(THCState *state, THCStorage *storage, const ch
5353
THC_API void THCStorage_(clearFlag)(THCState *state, THCStorage *storage, const char flag);
5454
THC_API void THCStorage_(retain)(THCState *state, THCStorage *storage);
5555

56-
/* used by StorageSharing */
57-
THC_API int THCStorage_(retainIfLive)(THCState *state, THCStorage *storage);
58-
5956
THC_API void THCStorage_(free)(THCState *state, THCStorage *storage);
6057
THC_API void THCStorage_(resize)(THCState *state, THCStorage *storage, ptrdiff_t size);
6158
THC_API void THCStorage_(fill)(THCState *state, THCStorage *storage, real value);

0 commit comments

Comments
 (0)