Skip to content

Commit da5bb37

Browse files
committed
Type conversions now use auto gpu
1 parent 6584b35 commit da5bb37

File tree

8 files changed

+161
-55
lines changed

8 files changed

+161
-55
lines changed

test/test_cuda.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,27 @@ def test_serialization(self):
280280
q_copy[1].fill_(10)
281281
self.assertTrue(q_copy[3], torch.cuda.IntStorage(10).fill_(10))
282282

283+
def test_type_conversions(self):
284+
x = torch.randn(5, 5)
285+
self.assertIs(type(x.float()), torch.FloatTensor)
286+
self.assertIs(type(x.cuda()), torch.cuda.DoubleTensor)
287+
self.assertIs(type(x.cuda().float()), torch.cuda.FloatTensor)
288+
self.assertIs(type(x.cuda().float().cpu()), torch.FloatTensor)
289+
self.assertIs(type(x.cuda().float().cpu().int()), torch.IntTensor)
290+
291+
y = x.storage()
292+
self.assertIs(type(y.float()), torch.FloatStorage)
293+
self.assertIs(type(y.cuda()), torch.cuda.DoubleStorage)
294+
self.assertIs(type(y.cuda().float()), torch.cuda.FloatStorage)
295+
self.assertIs(type(y.cuda().float().cpu()), torch.FloatStorage)
296+
self.assertIs(type(y.cuda().float().cpu().int()), torch.IntStorage)
297+
298+
@unittest.skipIf(torch.cuda.deviceCount() < 2, "only one GPU detected")
299+
def test_type_conversions_same_gpu(self):
300+
x = torch.randn(5, 5).cuda(1)
301+
self.assertEqual(x.int().getDevice(), 1)
302+
303+
283304
for decl in tests:
284305
for t in types:
285306
tensor = t()

torch/Storage.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
from ._utils import _type
23

34

45
class _StorageBase(object):
@@ -35,3 +36,27 @@ def clone(self):
3536

3637
def tolist(self):
3738
return [v for v in self]
39+
40+
def double(self):
41+
return self.type(type(self).__module__ + '.DoubleStorage')
42+
43+
def float(self):
44+
return self.type(type(self).__module__ + '.FloatStorage')
45+
46+
def long(self):
47+
return self.type(type(self).__module__ + '.LongStorage')
48+
49+
def int(self):
50+
return self.type(type(self).__module__ + '.IntStorage')
51+
52+
def short(self):
53+
return self.type(type(self).__module__ + '.ShortStorage')
54+
55+
def char(self):
56+
return self.type(type(self).__module__ + '.CharStorage')
57+
58+
def byte(self):
59+
return self.type(type(self).__module__ + '.ByteStorage')
60+
61+
62+
_StorageBase.type = _type

torch/Tensor.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
from . import TensorPrinting
3+
from ._utils import _type
34
from functools import reduce
45
from itertools import chain
56
import sys
@@ -25,43 +26,29 @@ class _TensorBase(object):
2526
def new(self, *args, **kwargs):
2627
return self.__class__(*args, **kwargs)
2728

28-
def type(self, t=None):
29-
if isinstance(t, str) or t is None:
30-
current = self.__module__ + '.' + self.__class__.__name__
31-
if t is None:
32-
return current
33-
if t == current:
34-
return self
35-
_, _, typename = t.partition('.')
36-
return torch._import_dotted_name(t)(self.size()).copy_(self)
37-
else:
38-
if t == type(self):
39-
return self
40-
return t(self.size()).copy_(self)
41-
4229
def typeAs(self, t):
4330
return self.type(t.type())
4431

4532
def double(self):
46-
return self.type(torch.DoubleTensor)
33+
return self.type(type(self).__module__ + '.DoubleTensor')
4734

4835
def float(self):
49-
return self.type(torch.FloatTensor)
36+
return self.type(type(self).__module__ + '.FloatTensor')
5037

5138
def long(self):
52-
return self.type(torch.LongTensor)
39+
return self.type(type(self).__module__ + '.LongTensor')
5340

5441
def int(self):
55-
return self.type(torch.IntTensor)
42+
return self.type(type(self).__module__ + '.IntTensor')
5643

5744
def short(self):
58-
return self.type(torch.ShortTensor)
45+
return self.type(type(self).__module__ + '.ShortTensor')
5946

6047
def char(self):
61-
return self.type(torch.CharTensor)
48+
return self.type(type(self).__module__ + '.CharTensor')
6249

6350
def byte(self):
64-
return self.type(torch.ByteTensor)
51+
return self.type(type(self).__module__ + '.ByteTensor')
6552

6653
def copy_(self, other):
6754
torch._C._tensorCopy(self, other)
@@ -267,3 +254,6 @@ def __mod__(self, other):
267254

268255
def __neg__(self):
269256
return self.neg()
257+
258+
259+
_TensorBase.type = _type

torch/_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import torch
2+
3+
4+
def _type(self, t=None):
5+
if isinstance(t, str) or t is None:
6+
current = self.__module__ + '.' + self.__class__.__name__
7+
if t is None:
8+
return current
9+
if t == current:
10+
return self
11+
_, _, typename = t.partition('.')
12+
return torch._import_dotted_name(t)(self.size()).copy_(self)
13+
else:
14+
if t == type(self):
15+
return self
16+
return t(self.size()).copy_(self)
17+

torch/csrc/generic/StorageMethods.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,15 @@ PyObject * THPStorage_(_sharedFd)(THPStorage *self)
405405
}
406406
#endif
407407

408+
#ifdef THC_GENERIC_FILE
409+
PyObject * THPStorage_(getDevice)(THPStorage *self)
410+
{
411+
HANDLE_TH_ERRORS
412+
return PyLong_FromLong(THCStorage_(getDevice)(LIBRARY_STATE self->cdata));
413+
END_HANDLE_TH_ERRORS
414+
}
415+
#endif
416+
408417
PyObject * THPStorage_(_setCdata)(THPStorage *self, PyObject *new_cdata)
409418
{
410419
HANDLE_TH_ERRORS
@@ -438,6 +447,9 @@ static PyMethodDef THPStorage_(methods)[] = {
438447
{"_get_shared_fd", (PyCFunction)THPStorage_(_sharedFd), METH_NOARGS, NULL},
439448
{"_shared_decref", (PyCFunction)THPStorage_(_sharedDecref), METH_NOARGS, NULL},
440449
{"_shared_incref", (PyCFunction)THPStorage_(_sharedIncref), METH_NOARGS, NULL},
450+
#endif
451+
#ifdef THC_GENERIC_FILE
452+
{"getDevice", (PyCFunction)THPStorage_(getDevice), METH_NOARGS, NULL},
441453
#endif
442454
{"_set_cdata", (PyCFunction)THPStorage_(_setCdata), METH_O, NULL},
443455
{NULL}

torch/cuda/__init__.py

Lines changed: 51 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -24,56 +24,76 @@
2424
"for old NVIDIA drivers")
2525

2626

27-
from torch.Storage import _StorageBase
28-
from torch.Tensor import _TensorBase
27+
@contextlib.contextmanager
28+
def device(idx):
29+
prev_idx = torch._C._cuda_getDevice()
30+
torch._C._cuda_setDevice(idx)
31+
yield
32+
torch._C._cuda_setDevice(prev_idx)
33+
34+
35+
@contextlib.contextmanager
36+
def _dummy_ctx():
37+
yield
38+
39+
40+
def deviceCount():
41+
return torch._C._cuda_getDeviceCount()
42+
2943

3044
################################################################################
3145
# Define Storage and Tensor classes
3246
################################################################################
3347

34-
class DoubleStorage(torch._C.CudaDoubleStorageBase, _StorageBase):
48+
49+
from .tensor import _CudaTensorBase
50+
from .storage import _CudaStorageBase
51+
52+
53+
class DoubleStorage(torch._C.CudaDoubleStorageBase, _CudaStorageBase):
3554
pass
36-
class FloatStorage(torch._C.CudaFloatStorageBase, _StorageBase):
55+
class FloatStorage(torch._C.CudaFloatStorageBase, _CudaStorageBase):
3756
pass
38-
class LongStorage(torch._C.CudaLongStorageBase, _StorageBase):
57+
class LongStorage(torch._C.CudaLongStorageBase, _CudaStorageBase):
3958
pass
40-
class IntStorage(torch._C.CudaIntStorageBase, _StorageBase):
59+
class IntStorage(torch._C.CudaIntStorageBase, _CudaStorageBase):
4160
pass
42-
class ShortStorage(torch._C.CudaShortStorageBase, _StorageBase):
61+
class ShortStorage(torch._C.CudaShortStorageBase, _CudaStorageBase):
4362
pass
44-
class CharStorage(torch._C.CudaCharStorageBase, _StorageBase):
63+
class CharStorage(torch._C.CudaCharStorageBase, _CudaStorageBase):
4564
pass
46-
class ByteStorage(torch._C.CudaByteStorageBase, _StorageBase):
65+
class ByteStorage(torch._C.CudaByteStorageBase, _CudaStorageBase):
4766
pass
48-
class HalfStorage(torch._C.CudaHalfStorageBase, _StorageBase):
67+
class HalfStorage(torch._C.CudaHalfStorageBase, _CudaStorageBase):
4968
pass
5069

51-
class DoubleTensor(torch._C.CudaDoubleTensorBase, _TensorBase):
70+
class DoubleTensor(torch._C.CudaDoubleTensorBase, _CudaTensorBase):
5271
def is_signed(self):
5372
return True
54-
class FloatTensor(torch._C.CudaFloatTensorBase, _TensorBase):
73+
class FloatTensor(torch._C.CudaFloatTensorBase, _CudaTensorBase):
5574
def is_signed(self):
5675
return True
57-
class LongTensor(torch._C.CudaLongTensorBase, _TensorBase):
76+
class LongTensor(torch._C.CudaLongTensorBase, _CudaTensorBase):
5877
def is_signed(self):
5978
return True
60-
class IntTensor(torch._C.CudaIntTensorBase, _TensorBase):
79+
class IntTensor(torch._C.CudaIntTensorBase, _CudaTensorBase):
6180
def is_signed(self):
6281
return True
63-
class ShortTensor(torch._C.CudaShortTensorBase, _TensorBase):
82+
class ShortTensor(torch._C.CudaShortTensorBase, _CudaTensorBase):
6483
def is_signed(self):
6584
return True
66-
class CharTensor(torch._C.CudaCharTensorBase, _TensorBase):
85+
class CharTensor(torch._C.CudaCharTensorBase, _CudaTensorBase):
6786
def is_signed(self):
6887
# TODO
6988
return False
70-
class ByteTensor(torch._C.CudaByteTensorBase, _TensorBase):
89+
class ByteTensor(torch._C.CudaByteTensorBase, _CudaTensorBase):
7190
def is_signed(self):
7291
return False
73-
class HalfTensor(torch._C.CudaHalfTensorBase, _TensorBase):
92+
class HalfTensor(torch._C.CudaHalfTensorBase, _CudaTensorBase):
7493
def is_signed(self):
7594
return True
7695

96+
7797
torch._storage_classes.add(DoubleStorage)
7898
torch._storage_classes.add(FloatStorage)
7999
torch._storage_classes.add(LongStorage)
@@ -90,36 +110,33 @@ def is_signed(self):
90110
torch._tensor_classes.add(CharTensor)
91111
torch._tensor_classes.add(ByteTensor)
92112

93-
@contextlib.contextmanager
94-
def device(idx):
95-
prev_idx = torch._C._cuda_getDevice()
96-
torch._C._cuda_setDevice(idx)
97-
yield
98-
torch._C._cuda_setDevice(prev_idx)
99113

100-
@contextlib.contextmanager
101-
def _dummy_ctx():
102-
yield
103-
104-
def _tensor_cuda(self, idx=None):
114+
def _cuda(self, idx=None):
105115
# This already is a CUDA tensor.
106116
# Let's check if it needs to be transfered to another GPU.
107117
if hasattr(self, 'getDevice'):
108118
target_device = idx if idx else torch._C._cuda_getDevice()
109119
if self.getDevice() != target_device:
110120
with device(target_device):
111121
return type(self)(self.size()).copy_(self)
122+
else:
123+
return self
112124
else:
113125
ctx = device(idx) if idx else _dummy_ctx()
114126
with ctx:
115127
return self.type(getattr(torch.cuda, self.__class__.__name__))
116-
_TensorBase.cuda = _tensor_cuda
117128

118-
def _tensor_cpu(self):
129+
130+
def _cpu(self):
119131
return self.type(getattr(torch, self.__class__.__name__))
120-
_TensorBase.cpu = _tensor_cpu
121132

122-
def deviceCount():
123-
return torch._C._cuda_getDeviceCount()
133+
134+
from ..Tensor import _TensorBase
135+
from ..Storage import _StorageBase
136+
_TensorBase.cuda = _cuda
137+
_TensorBase.cpu = _cpu
138+
_StorageBase.cuda = _cuda
139+
_StorageBase.cpu = _cpu
140+
124141

125142
assert torch._C._cuda_init()

torch/cuda/storage.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from . import device, _dummy_ctx
2+
from ..Storage import _StorageBase
3+
4+
5+
class _CudaStorageBase(_StorageBase):
6+
7+
def type(self, *args, **kwargs):
8+
source_device = self.getDevice()
9+
ctx = device(source_device) if source_device != -1 else _dummy_ctx()
10+
with ctx:
11+
return super(_CudaStorageBase, self).type(*args, **kwargs)
12+

torch/cuda/tensor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from . import device, _dummy_ctx
2+
from ..Tensor import _TensorBase
3+
4+
5+
class _CudaTensorBase(_TensorBase):
6+
7+
def type(self, *args, **kwargs):
8+
source_device = self.getDevice()
9+
ctx = device(source_device) if source_device != -1 else _dummy_ctx()
10+
with ctx:
11+
return super(_CudaTensorBase, self).type(*args, **kwargs)
12+

0 commit comments

Comments
 (0)