
Commit 45d03a5

ezyang authored and Rob Kunkle committed
Eliminate storage views. (pytorch#9466)
Summary:
Storage views were previously used to implement CUDA IPC sharing, but they weren't necessary. The new strategy is described in Note [CUDA IPC and the caching allocator].

This also fixes an unrelated bug, where we weren't actually using the Tensor forking pickler, because we didn't register a pickler for torch.Tensor.

Fixes pytorch#9447. Fixes pytorch#46.

Signed-off-by: Edward Z. Yang <[email protected]>

CC apaszke
Pull Request resolved: pytorch#9466
Reviewed By: apaszke
Differential Revision: D8859698
Pulled By: ezyang
fbshipit-source-id: 3362cb92f6ae4aa37084c57d79b31004bd0b4a97
1 parent e66ad2e commit 45d03a5
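
For readers following the pickler fix: torch.multiprocessing swaps in a custom reducer by registering it with Python's ForkingPickler, and the bug was that no reducer was ever registered for torch.Tensor itself. A minimal sketch of the registration, assuming the reduce_tensor reducer from torch.multiprocessing.reductions (treat the exact import path as an assumption about this era of the tree):

# Hedged sketch of the registration this commit adds; not the literal patch.
# ForkingPickler.register is the stock multiprocessing hook.
from multiprocessing.reduction import ForkingPickler

import torch
from torch.multiprocessing.reductions import reduce_tensor

# Without this line, torch.Tensor fell back to the default pickler and the
# custom shared-memory path was silently skipped.
ForkingPickler.register(torch.Tensor, reduce_tensor)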

File tree: 12 files changed, +139 −193 lines

aten/src/TH/THStorage.cpp

Lines changed: 0 additions & 4 deletions

@@ -27,9 +27,6 @@ void THStorage_free(THStorage *storage) {
       }
       storage->finalizer.~unique_ptr<THFinalizer>();
       storage->data_ptr.~DataPtr();
-      if (storage->flag & TH_STORAGE_VIEW) {
-        THStorage_free(storage->view);
-      }
       THStorage_weakFree(storage);
     }
   }

@@ -227,6 +224,5 @@ void THStorage_swap(THStorage *storage1, THStorage *storage2)
   SWAP(flag);
   SWAP(allocator);
   SWAP(finalizer);
-  SWAP(view);
 #undef SWAP
 }

aten/src/TH/THStorage.hpp

Lines changed: 0 additions & 1 deletion

@@ -47,7 +47,6 @@ typedef struct THStorage
   char flag;
   at::Allocator *allocator;
   std::unique_ptr<THFinalizer> finalizer;
-  struct THStorage *view;

   template <typename T>
   inline T * data() const {

aten/src/TH/generic/THStorage.h

Lines changed: 0 additions & 1 deletion

@@ -22,7 +22,6 @@

 #define TH_STORAGE_REFCOUNTED 1
 #define TH_STORAGE_RESIZABLE 2
-#define TH_STORAGE_VIEW 8

 // Struct definition is moved to THStorage.hpp (so this file stays C compatible)
 typedef struct THStorage THStorage;

aten/src/THC/THCStorage.cpp

Lines changed: 0 additions & 3 deletions

@@ -55,9 +55,6 @@ void THCStorage_free(THCState *state, THCStorage *storage)
       }
       storage->finalizer.~unique_ptr<THFinalizer>();
       storage->data_ptr.~DataPtr();
-      if (storage->flag & TH_STORAGE_VIEW) {
-        THCStorage_free(state, storage->view);
-      }
       THStorage_weakFree(storage);
     }
   }

test/test_multiprocessing.py

Lines changed: 9 additions & 6 deletions

@@ -207,18 +207,15 @@ def test_receive():
     def _test_preserve_sharing(self, ctx=mp, repeat=1):
         def do_test():
             x = torch.randn(5, 5)
-            data = [x.storage(), x.storage()[1:4], x, x[2], x[:, 1]]
+            data = [x.storage(), x, x[2], x[:, 1]]
             q = ctx.Queue()
             q.put(data)
             new_data = q.get(timeout=1)
             self.assertEqual(new_data, data, 0)
             storage_cdata = data[0]._cdata
             self.assertEqual(new_data[0]._cdata, storage_cdata)
-            for t in new_data[2:]:
+            for t in new_data[1:]:
                 self.assertEqual(t.storage()._cdata, storage_cdata)
-            # TODO: enable after fixing #46
-            # new_data[0].fill_(10)
-            # self.assertEqual(new_data[1], new_data[0][1:4], 0)

         with leak_checker(self):
             for i in range(repeat):

@@ -335,7 +332,13 @@ def test_cuda_small_tensors(self):
             self.assertEqual(v, torch.arange(i * 5., (i + 1) * 5).sum())
             self.assertEqual(device, i % 2)
             self.assertEqual(tensor_size, 5)
-            self.assertEqual(storage_size, 5)
+            # You might think this should be the case, but it's not!  After
+            # data from the CUDA caching allocator goes through IPC, the
+            # size of the storage is the size of the *cached cudaMalloc for
+            # the entire memory block* of the storage, not just the storage.
+            # See Note [CUDA IPC and the caching allocator] for more info
+            #
+            # self.assertEqual(storage_size, 5)

     @unittest.skipIf(IS_WINDOWS, 'not applicable to Windows (only fails with fork)')
     @unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
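
The disabled assertion above is the user-visible side of the new scheme: the receiving process maps the allocator's whole cudaMalloc'd block, so the rebuilt storage can be larger than the tensor it backs. A hedged sketch of what a consumer process may observe (block sizes depend on the caching allocator's rounding, so no exact value is asserted):

import torch
import torch.multiprocessing as mp


def consumer(q):
    t = q.get()
    # The tensor still has 5 elements...
    assert t.numel() == 5
    # ...but its storage may span the allocator's entire cached block.
    print('storage size:', t.storage().size())  # often > 5


if __name__ == '__main__':
    ctx = mp.get_context('spawn')
    q = ctx.Queue()
    p = ctx.Process(target=consumer, args=(q,))
    p.start()
    q.put(torch.arange(5.).cuda())
    p.join()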

test/test_torch.py

Lines changed: 13 additions & 22 deletions

@@ -6456,17 +6456,6 @@ def test_storage(self):
         self.assertEqual(v.storage()[0], v.data[0][0])
         self.assertEqual(v.storage()[14], v.data[2][4])

-    def test_storageview(self):
-        s1 = torch.LongStorage((3, 4, 5))
-        s2 = torch.LongStorage(s1, 1)
-
-        self.assertEqual(s2.size(), 2)
-        self.assertEqual(s2[0], s1[1])
-        self.assertEqual(s2[1], s1[2])
-
-        s2[1] = 13
-        self.assertEqual(13, s1[2])
-
     def test_nonzero(self):
         num_src = 12

@@ -6732,14 +6721,14 @@ def test_parsing_intlist(self):

     def _test_serialization_data(self):
         a = [torch.randn(5, 5).float() for i in range(2)]
-        b = [a[i % 2] for i in range(4)]
-        b += [a[0].storage()]
-        b += [a[0].storage()[1:4]]
-        b += [torch.arange(1, 11).int()]
-        t1 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
-        t2 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
-        b += [(t1.storage(), t1.storage(), t2.storage())]
-        b += [a[0].storage()[0:2]]
+        b = [a[i % 2] for i in range(4)]  # 0-3
+        b += [a[0].storage()]  # 4
+        b += [a[0].reshape(-1)[1:4].storage()]  # 5
+        b += [torch.arange(1, 11).int()]  # 6
+        t1 = torch.FloatTensor().set_(a[0].reshape(-1)[1:4].clone().storage(), 0, (3,), (1,))
+        t2 = torch.FloatTensor().set_(a[0].reshape(-1)[1:4].clone().storage(), 0, (3,), (1,))
+        b += [(t1.storage(), t1.storage(), t2.storage())]  # 7
+        b += [a[0].reshape(-1)[0:2].storage()]  # 8
         return b

     def _test_serialization_assert(self, b, c):

@@ -6754,7 +6743,10 @@ def _test_serialization_assert(self, b, c):
         self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
         c[1].fill_(20)
         self.assertEqual(c[1], c[3], 0)
-        self.assertEqual(c[4][1:4], c[5], 0)
+        # I have to do it in this roundabout fashion, because there's no
+        # way to slice storages
+        for i in range(4):
+            self.assertEqual(c[4][i + 1], c[5][i])

         # check that serializing the same storage view object unpickles
         # it as one object not two (and vice versa)

@@ -6914,7 +6906,7 @@ def test_serialization_backwards_compat(self):
         a = [torch.arange(1 + i, 26 + i).view(5, 5).float() for i in range(2)]
         b = [a[i % 2] for i in range(4)]
         b += [a[0].storage()]
-        b += [a[0].storage()[1:4]]
+        b += [a[0].reshape(-1)[1:4].clone().storage()]
         path = download_file('https://download.pytorch.org/test_data/legacy_serialized.pt')
         c = torch.load(path)
         self.assertEqual(b, c, 0)

@@ -6928,7 +6920,6 @@ def test_serialization_backwards_compat(self):
         self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
         c[1].fill_(20)
         self.assertEqual(c[1], c[3], 0)
-        self.assertEqual(c[4][1:4], c[5], 0)

         # test some old tensor serialization mechanism
         class OldTensorBase(object):
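
Since storages can no longer be sliced, the updated tests materialize a "storage slice" by slicing a flat tensor view and cloning, then compare element by element. The idiom in isolation, as a runnable sketch:

import torch

a = torch.randn(5, 5)
flat = a.reshape(-1)

# Old (removed): a.storage()[1:4]
# New: clone the sliced view so the copy owns a compact 3-element
# storage of its own, then take that storage.
s = flat[1:4].clone().storage()
assert len(s) == 3

# Comparing against the original storage is now done element-wise:
for i in range(3):
    assert s[i] == a.storage()[i + 1]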

torch/csrc/generic/Storage.cpp

Lines changed: 2 additions & 63 deletions

@@ -88,44 +88,8 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObject *kwargs)

   // torch.Storage(view_source, [offset, [size]])
   if (num_args < 4 && THPStorage_(Check)(first_arg)) {
-#ifdef THD_GENERIC_FILE
-    THPUtils_setError("distributed storages don't support storage views");
+    THPUtils_setError("storage views not supported");
     return NULL;
-#else
-    THPStorage *storage_arg = (THPStorage *)first_arg;
-    int64_t numel = storage_arg->cdata->size;
-    int64_t offset = 0;
-
-    if (num_args >= 2) {
-      PyObject *second_arg = PyTuple_GET_ITEM(args, 1);
-      if (!THPUtils_checkLong(second_arg))
-        goto invalid_arguments;
-      offset = THPUtils_unpackLong(second_arg);
-    }
-
-    int64_t size = numel - offset;
-    if (num_args >= 3) {
-      PyObject *third_arg = PyTuple_GET_ITEM(args, 2);
-      if (!THPUtils_checkLong(third_arg))
-        goto invalid_arguments;
-      size = THPUtils_unpackLong(third_arg);
-    }
-
-    THPUtils_assert(offset >= 0 && offset <= numel, "specified an offset of "
-      "%" PRId64 ", but the viewed storage has only %" PRId64 " element(s)", offset, numel);
-    THPUtils_assert(size >= 1 && size <= numel - offset, "specified a size of "
-      "%" PRId64 ", but the viewed storage has only %" PRId64 " element(s) after offset %" PRId64,
-      size, numel - offset, offset);
-
-    real *data_ptr = THWStorage_(data)(LIBRARY_STATE storage_arg->cdata) + offset;
-    // TODO: Hmmmm
-    THWStoragePtr storage(THWStorage_(newWithDataAndAllocator)(LIBRARY_STATE {data_ptr, storage_arg->cdata->data_ptr.device()} /* non-owning */, size, nullptr));
-    storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_VIEW;
-    storage->view = storage_arg->cdata;
-    THWStorage_(retain)(LIBRARY_STATE storage_arg->cdata);
-    self->cdata = storage.release();
-    return (PyObject*)self.release();
-#endif
   }

   // torch.Storage(sequence)

@@ -161,9 +125,6 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObject *kwargs)
 #endif
   }

-#ifndef THD_GENERIC_FILE
- invalid_arguments:
-#endif
   THPUtils_invalidArguments(args, kwargs, THPStorageStr " constructor", 6,
     "no arguments",
     "(int size)",

@@ -199,30 +160,8 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
     return THPUtils_(newReal)(value);
   /* Slice index */
   } else if (PySlice_Check(index)) {
-#ifdef THD_GENERIC_FILE
-    THPUtils_setError("distributed storages don't support slicing");
+    THPUtils_setError("storages don't support slicing");
     return NULL;
-#else
-    Py_ssize_t start, stop, slicelength, step;
-    int64_t len = THWStorage_(size)(LIBRARY_STATE self->cdata);
-    if (!THPUtils_parseSlice(index, len, &start, &stop, &step, &slicelength))
-      return NULL;
-    if (step != 1) {
-      THPUtils_setError("Trying to slice with a step of %" PRId64 ", but only a step of "
-        "1 is supported", (int64_t)step);
-      return NULL;
-    }
-
-    real *data = THWStorage_(data)(LIBRARY_STATE self->cdata);
-    THWStoragePtr new_storage(THWStorage_(newWithDataAndAllocator)(LIBRARY_STATE {static_cast<void*>(data + start), self->cdata->data_ptr.device()} /* non-owning */, slicelength, nullptr));
-    new_storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_VIEW;
-    new_storage->view = self->cdata;
-    THWStorage_(retain)(LIBRARY_STATE self->cdata);
-
-    PyObject *_ret = THPStorage_(New)(new_storage);
-    new_storage.release();
-    return _ret;
-#endif
   }
   PyErr_Format(PyExc_TypeError, "can't index a " THPStorageStr " with %s",
     THPUtils_typename(index));
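
Both removed C paths now surface in Python as immediate errors. A quick sketch of what code that relied on them hits after this change (the exact exception type raised via THPUtils_setError is an assumption; the messages come from the diff above):

import torch

s1 = torch.LongStorage(5)

try:
    s2 = torch.LongStorage(s1, 1)  # view constructor, now rejected
except Exception as e:
    print(e)  # "storage views not supported"

try:
    s3 = s1[1:4]  # storage slicing, now rejected
except Exception as e:
    print(e)  # "storages don't support slicing"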

torch/csrc/generic/StorageMethods.cpp

Lines changed: 0 additions & 21 deletions

@@ -292,26 +292,6 @@ PyObject * THPStorage_(_setCdata)(THPStorage *self, PyObject *new_cdata)
   END_HANDLE_TH_ERRORS
 }

-#ifndef THD_GENERIC_FILE
-PyObject * THPStorage_(_rootStorage)(THPStorage *self)
-{
-  HANDLE_TH_ERRORS
-  if (!(self->cdata->flag & TH_STORAGE_VIEW)) {
-    return Py_BuildValue("(ON)", self, PyLong_FromLong(0));
-  }
-  THWStorage *root = self->cdata;
-  while (root->flag & TH_STORAGE_VIEW)
-    root = root->view;
-  size_t offset = THWStorage_(data)(LIBRARY_STATE self->cdata) - THWStorage_(data)(LIBRARY_STATE root);
-  THWStorage_(retain)(LIBRARY_STATE root);
-  THPObjectPtr storage(THPStorage_(New)(root));
-  PyObject *result = Py_BuildValue("(NN)", storage.get(), PyLong_FromLong(offset));
-  storage.release();
-  return result;
-  END_HANDLE_TH_ERRORS
-}
-#endif
-
 static PyMethodDef THPStorage_(methods)[] = {
   {"copy_", (PyCFunction)THPStorage_(copy_), METH_VARARGS | METH_KEYWORDS, NULL},
   {"element_size", (PyCFunction)THPStorage_(elementSize), METH_NOARGS, NULL},

@@ -335,7 +315,6 @@ static PyMethodDef THPStorage_(methods)[] = {
 #endif
   {"_set_cdata", (PyCFunction)THPStorage_(_setCdata), METH_O, NULL},
 #ifndef THD_GENERIC_FILE
-  {"_root_storage", (PyCFunction)THPStorage_(_rootStorage), METH_NOARGS, NULL},
 #endif
   {NULL}
 };
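
With storage views gone, every storage is its own root, so the (root, offset) pair that _root_storage walked up to compute degenerates: the root is just the storage, and offsets live on tensors. A hedged sketch of the equivalent bookkeeping at the Python level:

import torch

x = torch.randn(5, 5)
v = x[2:4]  # a tensor view sharing x's storage at an offset

# The storage is not a view of anything; it is the root.
root = v.storage()
offset = v.storage_offset()

assert root._cdata == x.storage()._cdata  # same underlying storage
assert offset == 2 * 5                    # row 2 of a 5-wide tensor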
