Skip to content

Commit 1875c2e

Browse files
hameerabbasifacebook-github-bot
authored andcommitted
Add torch.Tensor.as_subclass method. (#34369)
Summary: This is according to pytorch/rfcs#3. Pull Request resolved: #34369 Differential Revision: D20963929 Pulled By: ezyang fbshipit-source-id: e618af6fd36e1dfaeda617162314ad5840f55358
1 parent 7c825ba commit 1875c2e

File tree

4 files changed

+83
-2
lines changed

4 files changed

+83
-2
lines changed

docs/source/tensors.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ view of a storage and defines numeric operations on it.
353353
.. automethod:: lt_
354354
.. automethod:: lu
355355
.. automethod:: lu_solve
356+
.. automethod:: as_subclass
356357
.. automethod:: map_
357358
.. automethod:: masked_scatter_
358359
.. automethod:: masked_scatter

test/test_torch.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -920,6 +920,61 @@ def test_dtype_out_match(self):
920920
d = torch.autograd.Variable(torch.DoubleTensor(2, 3))
921921
self.assertRaises(RuntimeError, lambda: torch.zeros((2, 3), out=d, dtype=torch.float32))
922922

923+
def test_as_subclass(self):
924+
class SubTensor(torch.Tensor):
925+
member_var = object()
926+
927+
t0 = torch.tensor(0)
928+
t1 = torch.tensor([1, 2])
929+
t2 = torch.tensor([[3, 4], [5, 6]])
930+
931+
s0 = t0.as_subclass(SubTensor)
932+
s1 = t1.as_subclass(SubTensor)
933+
s2 = t2.as_subclass(SubTensor)
934+
935+
# Check that the correct type is returned.
936+
self.assertTrue(type(s0) is SubTensor)
937+
self.assertTrue(type(s1) is SubTensor)
938+
self.assertTrue(type(s2) is SubTensor)
939+
940+
# Check that the data is equal.
941+
self.assertEqual(t0, s0)
942+
self.assertEqual(t1, s1)
943+
self.assertEqual(t2, s2)
944+
945+
t0[()] = 1
946+
t1[1] = 3
947+
t2[1, 1] = 7
948+
949+
# Check that the data is equal even after modification.
950+
self.assertEqual(t0, s0)
951+
self.assertEqual(t1, s1)
952+
self.assertEqual(t2, s2)
953+
954+
# Check that member variables are passed through.
955+
self.assertTrue(s0.member_var is SubTensor.member_var)
956+
self.assertTrue(s1.member_var is SubTensor.member_var)
957+
self.assertTrue(s2.member_var is SubTensor.member_var)
958+
959+
# Test that autograd is propagated.
960+
t = torch.tensor(5, dtype=torch.float32, requires_grad=True)
961+
962+
# Run a calculation on the tensor.
963+
exp_t = torch.exp(t)
964+
965+
# Cast exp_t to a subclass.
966+
exp_s = exp_t.as_subclass(SubTensor)
967+
968+
# Make sure that t.grad was initially None
969+
self.assertTrue(t.grad is None)
970+
971+
# Run the autograd calculation.
972+
exp_s.backward()
973+
974+
# Make sure autograd was propagated to the original tensor
975+
# declared with requires_grad.
976+
self.assertTrue(t.grad is not None)
977+
923978
def test_constructor_dtypes(self):
924979
default_type = torch.Tensor().type()
925980
self.assertIs(torch.Tensor().dtype, torch.get_default_dtype())

torch/_tensor_docs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3656,3 +3656,12 @@ def callable(a, b) -> number
36563656
If ``n`` is the number of dimensions in ``x``,
36573657
``x.T`` is equivalent to ``x.permute(n-1, n-2, ..., 0)``.
36583658
""")
3659+
3660+
add_docstr_all('as_subclass',
3661+
r"""
3662+
as_subclass(cls) -> Tensor
3663+
3664+
Makes a ``cls`` instance with the same data pointer as ``self``. Changes
3665+
in the output mirror changes in ``self``, and the output stays attached
3666+
to the autograd graph. ``cls`` must be a subclass of ``Tensor``.
3667+
""")

torch/csrc/autograd/python_variable.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,22 @@ static PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject
143143
END_HANDLE_TH_ERRORS
144144
}
145145

146-
// Instantiates a subclass of torch.Tensor. Used by nn.Parameter()
146+
// Instantiates a subclass of self with the same data.
147+
static PyObject* THPVariable_as_subclass(THPVariable* self, PyObject* args, PyObject* kwargs) {
148+
HANDLE_TH_ERRORS
149+
static PythonArgParser parser({
150+
"as_subclass(PyObject* cls)",
151+
});
152+
ParsedArgs<1> parsed_args{};
153+
auto r = parser.parse(args, kwargs, parsed_args);
154+
PyObject* cls = r.pyobject(0);
155+
if (!PyType_Check(cls)) {
156+
throw TypeError("cls must be a type (got %s)", Py_TYPE(cls)->tp_name);
157+
}
158+
return THPVariable_NewWithVar((PyTypeObject*)cls, self->cdata.alias());
159+
END_HANDLE_TH_ERRORS
160+
}
161+
147162
static PyObject* THPVariable_make_subclass(PyObject* _ignored, PyObject* args, PyObject* kwargs) {
148163
HANDLE_TH_ERRORS
149164
static PythonArgParser parser({
@@ -539,7 +554,8 @@ static PyMappingMethods THPVariable_as_mapping = {
539554
};
540555

541556
static PyMethodDef extra_methods[] = {
542-
{"_make_subclass", (PyCFunction)(void(*)(void))THPVariable_make_subclass, METH_STATIC | METH_VARARGS | METH_KEYWORDS, nullptr},
557+
{"as_subclass", (PyCFunction)THPVariable_as_subclass, METH_VARARGS | METH_KEYWORDS, nullptr},
558+
{"_make_subclass", (PyCFunction)THPVariable_make_subclass, METH_STATIC | METH_VARARGS | METH_KEYWORDS, nullptr},
543559
{nullptr}
544560
};
545561

0 commit comments

Comments
 (0)