Commit a389ba8

Support optional_tensor_names in TorchAOBaseTensor
Summary: Allows subclasses inheriting from TorchAOBaseTensor to have optional tensor attributes, declared through an `optional_tensor_data_names` list. All common util functions are updated to support it, including `__tensor_flatten__`, `__tensor_unflatten__`, and ops such as aten._to_copy, contiguous, and alias.

Test Plan: python test/test_utils.py

stack-info: PR: #2710, branch: jerryzh168/stack/17
1 parent abcddcf commit a389ba8
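
For reference, a minimal sketch of how a subclass opts in (adapted from the new test in test/test_utils.py below; not itself part of the commit): tensor attributes that may be None go in `optional_tensor_data_names`, alongside the required `tensor_data_names` and `tensor_attribute_names`.

import torch
from torchao.utils import TorchAOBaseTensor

# Sketch adapted from the test below; `zero_point` may be None, and the
# default implementations skip it wherever it is None.
class MyTensorWithOptionalData(TorchAOBaseTensor):
    tensor_data_names = ["qdata"]
    optional_tensor_data_names = ["zero_point"]
    tensor_attribute_names = ["attr", "device"]

    def __new__(cls, qdata, zero_point=None, attr=1.0, device=None):
        if device is None:
            device = qdata.device
        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=device)  # type: ignore[attr-defined]

    def __init__(self, qdata, zero_point=None, attr=1.0, device=None):
        self.qdata = qdata
        self.zero_point = zero_point  # optional tensor attribute
        self.attr = attr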

2 files changed: +154 −42 lines
test/test_utils.py
Lines changed: 107 additions & 41 deletions

@@ -55,34 +55,11 @@ def __init__(self, data):
         with self.assertRaisesRegex(NotImplementedError, "arg_types"):
             l.weight = torch.nn.Parameter(MyTensor(l.weight))
 
-    @skip_if_no_cuda()
-    def test_default_impls(self):
-        """Making sure some common functions have default implementations, such as
-        __tensor_unflatten__, __tensor_flatten__, _apply_fn_to_data, __repr__, to
-        """
-
-        class MyTensor(TorchAOBaseTensor):
-            tensor_data_names = ["qdata"]
-            tensor_attribute_names = ["attr", "device"]
-
-            def __new__(cls, qdata, attr, device=None):
-                shape = qdata.shape
-                if device is None:
-                    device = qdata.device
-                kwargs = {"device": device}
-                return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
-
-            def __init__(self, qdata, attr, device=None):
-                self.qdata = qdata
-                self.attr = attr
-
-        l = torch.nn.Linear(2, 3)
-        l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
-        lp_tensor = l.weight
+    def _test_default_impls_helper(self, lp_tensor, lp_tensor_for_copy):
         # test __tensor_flatten__ and __tensor_unflatten__
-        tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
+        tensor_data_names, tensor_attributes = lp_tensor.__tensor_flatten__()
         tensor_data_dict = {
-            name: getattr(lp_tensor, name) for name in tensor_data_name_dict
+            name: getattr(lp_tensor, name) for name in tensor_data_names
         }
         outer_size = lp_tensor.size()
         outer_stride = lp_tensor.stride()
@@ -100,31 +77,120 @@ def __init__(self, qdata, attr, device=None):
         self.assertEqual(lp_tensor.device, original_device)
 
         # __repr__
-        print(lp_tensor)
+        _ = str(lp_tensor)
 
         # other ops
         lp_tensor = lp_tensor.detach()
         # explicitly testing aten.alias
         lp_tensor = torch.ops.aten.alias(lp_tensor)
         lp_tensor = lp_tensor.clone()
-        # making qdata not contiguous
-        lp_tensor.qdata = lp_tensor.qdata.transpose(0, 1).contiguous()
-        lp_tensor.qdata = lp_tensor.qdata.transpose(0, 1)
-        self.assertFalse(lp_tensor.qdata.is_contiguous())
-        lp_tensor = lp_tensor.contiguous()
-        # making sure contiguous call works
-        self.assertTrue(lp_tensor.qdata.is_contiguous())
+        # get all tensor_data_names, for both
+        # non-optional and valid optional tensors
+        tensor_data_names = lp_tensor.tensor_data_names.copy()
+        if hasattr(lp_tensor, "optional_tensor_data_names"):
+            for tensor_data_name in lp_tensor.optional_tensor_data_names:
+                if getattr(lp_tensor, tensor_data_name) is not None:
+                    tensor_data_names.append(tensor_data_name)
+
+        # for each of the tensor data, we try to
+        # make it non-contiguous and then use the
+        # lp_tensor.contiguous() call to make sure
+        # contiguous() works
+        for tensor_data_name in tensor_data_names:
+            tensor = getattr(lp_tensor, tensor_data_name)
+            # making the tensor data not contiguous
+            tensor = tensor.transpose(0, 1).contiguous()
+            tensor = tensor.transpose(0, 1)
+            setattr(lp_tensor, tensor_data_name, tensor)
+            self.assertFalse(getattr(lp_tensor, tensor_data_name).is_contiguous())
+            lp_tensor = lp_tensor.contiguous()
+            # making sure contiguous call works
+            self.assertTrue(getattr(lp_tensor, tensor_data_name).is_contiguous())
 
         # copy_
+        # making sure that initially tensor values are not the same so we can test copy_
+        self.assertNotEqual(lp_tensor.qdata[0][0], lp_tensor_for_copy.qdata[0][0])
+        # copy_ requires the attributes to be the same
+        for tensor_attr_name in lp_tensor.tensor_attribute_names:
+            self.assertEqual(
+                getattr(lp_tensor, tensor_attr_name),
+                getattr(lp_tensor_for_copy, tensor_attr_name),
+            )
+        lp_tensor.copy_(lp_tensor_for_copy)
+        # after copy_, the tensor values should match
+        for tensor_data_name in tensor_data_names:
+            self.assertTrue(
+                torch.equal(
+                    getattr(lp_tensor, tensor_data_name),
+                    getattr(lp_tensor_for_copy, tensor_data_name),
+                )
+            )
+
+    @skip_if_no_cuda()
+    def test_default_impls(self):
+        """Making sure some common functions have default implementations, such as
+        __tensor_unflatten__, __tensor_flatten__, _apply_fn_to_data, __repr__, to
+        """
+
+        class MyTensor(TorchAOBaseTensor):
+            tensor_data_names = ["qdata"]
+            tensor_attribute_names = ["attr", "device"]
+
+            def __new__(cls, qdata, attr, device=None):
+                shape = qdata.shape
+                if device is None:
+                    device = qdata.device
+                kwargs = {"device": device}
+                return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+            def __init__(self, qdata, attr, device=None):
+                self.qdata = qdata
+                self.attr = attr
+
+        l = torch.nn.Linear(2, 3)
+        l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
+        lp_tensor = l.weight
+
         another_tensor = torch.nn.Linear(2, 3).weight
         # attribute has to be the same
-        another_lp_tensor = MyTensor(another_tensor, "attr")
-        # initially tensor values are not the same
-        self.assertNotEqual(lp_tensor.qdata[0][0], another_lp_tensor.qdata[0][0])
-        lp_tensor.copy_(another_lp_tensor)
-        self.assertEqual(lp_tensor.attr, "attr")
-        # after copy_, the tensor values should match
-        self.assertEqual(lp_tensor.qdata[0][0], another_lp_tensor.qdata[0][0])
+        lp_tensor_for_copy = MyTensor(another_tensor, "attr")
+        self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
+
+    def test_default_impls_with_optional_data(self):
+        class MyTensorWithOptionalData(TorchAOBaseTensor):
+            tensor_data_names = ["qdata"]
+            optional_tensor_data_names = ["zero_point"]
+            tensor_attribute_names = ["attr", "device"]
+
+            def __new__(cls, qdata, zero_point=None, attr=1.0, device=None):
+                shape = qdata.shape
+                if device is None:
+                    device = qdata.device
+                kwargs = {"device": device}
+                return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+            def __init__(self, qdata, zero_point=None, attr=1.0, device=None):
+                self.qdata = qdata
+                self.zero_point = zero_point
+                self.attr = attr
+
+        # test both the case where the optional Tensor is None
+        # and the case where it is not None
+        l = torch.nn.Linear(2, 3)
+        lp_tensor = MyTensorWithOptionalData(l.weight, None, "attr")
+        l = torch.nn.Linear(2, 3)
+        lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, None, "attr")
+        self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
+
+        l = torch.nn.Linear(2, 3)
+        lp_tensor = MyTensorWithOptionalData(
+            l.weight, torch.zeros_like(l.weight), "attr"
+        )
+        l = torch.nn.Linear(2, 3)
+        lp_tensor_for_copy = MyTensorWithOptionalData(
+            l.weight, torch.zeros_like(l.weight), "attr"
+        )
+        self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
 
 
 if __name__ == "__main__":
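
Taken together with the torchao/utils.py changes below, the observable flatten behavior for a subclass like `MyTensorWithOptionalData` is roughly the following (a sketch, not part of the commit, assuming the class sketched near the top):

import torch

l = torch.nn.Linear(2, 3)

# Optional tensor left as None: the default __tensor_flatten__ below
# omits it from the flattened names.
t_none = MyTensorWithOptionalData(l.weight, None, "attr")
names, _attrs = t_none.__tensor_flatten__()
assert names == ["qdata"]

# Optional tensor provided: it is appended after the required names.
t_zp = MyTensorWithOptionalData(l.weight, torch.zeros_like(l.weight), "attr")
names, _attrs = t_zp.__tensor_flatten__()
assert names == ["qdata", "zero_point"]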

torchao/utils.py
Lines changed: 47 additions & 1 deletion

@@ -463,6 +463,16 @@ def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool:
         getattr(self, t_name).shape == getattr(src, t_name).shape
         for t_name in self.tensor_data_names
     )
+    _optional_tensor_shape_match = True
+    if hasattr(self, "optional_tensor_data_names"):
+        # either both are None, or both are Tensors and the shapes match
+        _optional_tensor_shape_match = all(
+            getattr(self, t_name).shape == getattr(src, t_name).shape
+            if getattr(self, t_name) is not None
+            else getattr(src, t_name) is None
+            for t_name in self.optional_tensor_data_names
+        )
+
     _attr_match = all(
         getattr(self, a_name) == getattr(src, a_name)
         for a_name in self.tensor_attribute_names
@@ -471,6 +481,7 @@ def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool:
         type(self) == type(src)
         and self.shape == src.shape
         and _tensor_shape_match
+        and _optional_tensor_shape_match
         and _attr_match
     )
 
@@ -498,6 +509,14 @@ def _(func, types, args, kwargs):
         tensors = [
             getattr(self, name).to(device) for name in self.tensor_data_names
         ]
+        if hasattr(self, "optional_tensor_data_names"):
+            for tensor_data_name in self.optional_tensor_data_names:
+                maybe_tensor = getattr(self, tensor_data_name)
+                if maybe_tensor is not None:
+                    tensors.append(maybe_tensor.to(device))
+                else:
+                    tensors.append(None)
+
         # change device
         tensor_attributes = [
             getattr(self, attr_name) if attr_name != "device" else device
@@ -699,7 +718,14 @@ def __tensor_flatten__(self):
         if hasattr(self, "tensor_data_names") and hasattr(
             self, "tensor_attribute_names"
         ):
-            return self.tensor_data_names, [
+            tensor_data_names = self.tensor_data_names.copy()
+            if hasattr(self, "optional_tensor_data_names"):
+                for tensor_data_name in self.optional_tensor_data_names:
+                    maybe_tensor = getattr(self, tensor_data_name)
+                    if maybe_tensor is not None:
+                        tensor_data_names.append(tensor_data_name)
+
+            return tensor_data_names, [
                 getattr(self, attr) for attr in self.tensor_attribute_names
             ]
         raise NotImplementedError(
@@ -711,13 +737,27 @@ def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         tensors = [tensor_data_dict[name] for name in cls.tensor_data_names]
+        if hasattr(cls, "optional_tensor_data_names"):
+            for tensor_data_name in cls.optional_tensor_data_names:
+                if tensor_data_name in tensor_data_dict:
+                    tensors.append(tensor_data_dict[tensor_data_name])
+                else:
+                    tensors.append(None)
         return cls(*tensors, *tensor_attributes)
 
     def _apply_fn_to_data(self, fn):
         if hasattr(self, "tensor_data_names") and hasattr(
             self, "tensor_attribute_names"
         ):
             tensors = [fn(getattr(self, attr)) for attr in self.tensor_data_names]
+            if hasattr(self, "optional_tensor_data_names"):
+                for tensor_data_name in self.optional_tensor_data_names:
+                    maybe_tensor = getattr(self, tensor_data_name)
+                    if maybe_tensor is not None:
+                        tensors.append(fn(maybe_tensor))
+                    else:
+                        tensors.append(None)
+
             tensor_attributes = [
                 getattr(self, attr) for attr in self.tensor_attribute_names
             ]
@@ -738,6 +778,12 @@ def __repr__(self):
             repr_str += f"{self.tensor_data_names[0]}={getattr(self, self.tensor_data_names[0])}"
             for tensor_data_name in self.tensor_data_names[1:]:
                 repr_str += f", {tensor_data_name}={getattr(self, tensor_data_name)}"
+            if hasattr(self, "optional_tensor_data_names"):
+                for tensor_data_name in self.optional_tensor_data_names:
+                    repr_str += (
+                        f", {tensor_data_name}={getattr(self, tensor_data_name)}"
+                    )
+
             for tensor_attribute_name in self.tensor_attribute_names:
                 repr_str += (
                     f", {tensor_attribute_name}={getattr(self, tensor_attribute_name)}"
