@@ -55,34 +55,11 @@ def __init__(self, data):
55
55
with self .assertRaisesRegex (NotImplementedError , "arg_types" ):
56
56
l .weight = torch .nn .Parameter (MyTensor (l .weight ))
57
57
58
- @skip_if_no_cuda ()
59
- def test_default_impls (self ):
60
- """Making sure some common functions has default implementations, such as
61
- __tensor_unflatten__, __tensor_flatten__, _apply_fn_to_data, __repr__, to
62
- """
63
-
64
- class MyTensor (TorchAOBaseTensor ):
65
- tensor_data_names = ["qdata" ]
66
- tensor_attribute_names = ["attr" , "device" ]
67
-
68
- def __new__ (cls , qdata , attr , device = None ):
69
- shape = qdata .shape
70
- if device is None :
71
- device = qdata .device
72
- kwargs = {"device" : device }
73
- return torch .Tensor ._make_wrapper_subclass (cls , shape , ** kwargs ) # type: ignore[attr-defined]
74
-
75
- def __init__ (self , qdata , attr , device = None ):
76
- self .qdata = qdata
77
- self .attr = attr
78
-
79
- l = torch .nn .Linear (2 , 3 )
80
- l .weight = torch .nn .Parameter (MyTensor (l .weight , "attr" ))
81
- lp_tensor = l .weight
58
+ def _test_default_impls_helper (self , lp_tensor , lp_tensor_for_copy ):
82
59
# test __tensor_flatten__ and __tensor_unflatten__
83
- tensor_data_name_dict , tensor_attributes = lp_tensor .__tensor_flatten__ ()
60
+ tensor_data_names , tensor_attributes = lp_tensor .__tensor_flatten__ ()
84
61
tensor_data_dict = {
85
- name : getattr (lp_tensor , name ) for name in tensor_data_name_dict
62
+ name : getattr (lp_tensor , name ) for name in tensor_data_names
86
63
}
87
64
outer_size = lp_tensor .size ()
88
65
outer_stride = lp_tensor .stride ()
@@ -107,24 +84,102 @@ def __init__(self, qdata, attr, device=None):
107
84
# explicitly testing aten.alias
108
85
lp_tensor = torch .ops .aten .alias (lp_tensor )
109
86
lp_tensor = lp_tensor .clone ()
110
- # making qdata not contiguous
111
- lp_tensor .qdata = lp_tensor .qdata .transpose (0 , 1 ).contiguous ()
112
- lp_tensor .qdata = lp_tensor .qdata .transpose (0 , 1 )
113
- self .assertFalse (lp_tensor .qdata .is_contiguous ())
114
- lp_tensor = lp_tensor .contiguous ()
115
- # making sure contiguous call works
116
- self .assertTrue (lp_tensor .qdata .is_contiguous ())
87
+ # get all tensor_data_names for both
88
+ # non optional and valid optional tensors
89
+ tensor_data_names = lp_tensor .tensor_data_names .copy ()
90
+ if hasattr (lp_tensor , "optional_tensor_data_names" ):
91
+ for tensor_data_name in lp_tensor .optional_tensor_data_names :
92
+ if getattr (lp_tensor , tensor_data_name ) is not None :
93
+ tensor_data_names .append (tensor_data_name )
94
+
95
+ # for each of the tensor data, we try to
96
+ # make it non-contiguous and then use
97
+ # lp_tensor.contiguous() call to make sure
98
+ # contiguous() works
99
+ for tensor_data_name in tensor_data_names :
100
+ tensor = getattr (lp_tensor , tensor_data_name )
101
+ # making qdata not contiguous
102
+ tensor = tensor .transpose (0 , 1 ).contiguous ()
103
+ tensor = tensor .transpose (0 , 1 )
104
+ setattr (lp_tensor , tensor_data_name , tensor )
105
+ self .assertFalse (getattr (lp_tensor , tensor_data_name ).is_contiguous ())
106
+ lp_tensor = lp_tensor .contiguous ()
107
+ # making sure contiguous call works
108
+ self .assertTrue (getattr (lp_tensor , tensor_data_name ).is_contiguous ())
117
109
118
110
# copy_
119
- another_tensor = torch .nn .Linear (2 , 3 ).weight
120
- # attribute has to be the same
121
- another_lp_tensor = MyTensor (another_tensor , "attr" )
122
111
# initially tensor values are not the same
123
- self .assertNotEqual (lp_tensor .qdata [0 ][0 ], another_lp_tensor .qdata [0 ][0 ])
124
- lp_tensor .copy_ (another_lp_tensor )
112
+ self .assertNotEqual (lp_tensor .qdata [0 ][0 ], lp_tensor_for_copy .qdata [0 ][0 ])
113
+ lp_tensor .copy_ (lp_tensor_for_copy )
125
114
self .assertEqual (lp_tensor .attr , "attr" )
126
115
# after copy_, the tensor values should match
127
- self .assertEqual (lp_tensor .qdata [0 ][0 ], another_lp_tensor .qdata [0 ][0 ])
116
+ self .assertEqual (lp_tensor .qdata [0 ][0 ], lp_tensor_for_copy .qdata [0 ][0 ])
117
+
118
+ @skip_if_no_cuda ()
119
+ def test_default_impls (self ):
120
+ """Making sure some common functions has default implementations, such as
121
+ __tensor_unflatten__, __tensor_flatten__, _apply_fn_to_data, __repr__, to
122
+ """
123
+
124
+ class MyTensor (TorchAOBaseTensor ):
125
+ tensor_data_names = ["qdata" ]
126
+ tensor_attribute_names = ["attr" , "device" ]
127
+
128
+ def __new__ (cls , qdata , attr , device = None ):
129
+ shape = qdata .shape
130
+ if device is None :
131
+ device = qdata .device
132
+ kwargs = {"device" : device }
133
+ return torch .Tensor ._make_wrapper_subclass (cls , shape , ** kwargs ) # type: ignore[attr-defined]
134
+
135
+ def __init__ (self , qdata , attr , device = None ):
136
+ self .qdata = qdata
137
+ self .attr = attr
138
+
139
+ l = torch .nn .Linear (2 , 3 )
140
+ l .weight = torch .nn .Parameter (MyTensor (l .weight , "attr" ))
141
+ lp_tensor = l .weight
142
+
143
+ another_tensor = torch .nn .Linear (2 , 3 ).weight
144
+ # attribute has to be the same
145
+ lp_tensor_for_copy = MyTensor (another_tensor , "attr" )
146
+ self ._test_default_impls_helper (lp_tensor , lp_tensor_for_copy )
147
+
148
+ def test_default_impls_with_optional_data (self ):
149
+ class MyTensorWithOptionalData (TorchAOBaseTensor ):
150
+ tensor_data_names = ["qdata" ]
151
+ optional_tensor_data_names = ["zero_point" ]
152
+ tensor_attribute_names = ["attr" , "device" ]
153
+
154
+ def __new__ (cls , qdata , zero_point = None , attr = 1.0 , device = None ):
155
+ shape = qdata .shape
156
+ if device is None :
157
+ device = qdata .device
158
+ kwargs = {"device" : device }
159
+ return torch .Tensor ._make_wrapper_subclass (cls , shape , ** kwargs ) # type: ignore[attr-defined]
160
+
161
+ def __init__ (self , qdata , zero_point = None , attr = 1.0 , device = None ):
162
+ self .qdata = qdata
163
+ self .zero_point = zero_point
164
+ self .attr = attr
165
+
166
+ # test both the optional Tensor is None
167
+ # and not None
168
+ l = torch .nn .Linear (2 , 3 )
169
+ lp_tensor = MyTensorWithOptionalData (l .weight , None , "attr" )
170
+ l = torch .nn .Linear (2 , 3 )
171
+ lp_tensor_for_copy = MyTensorWithOptionalData (l .weight , None , "attr" )
172
+ self ._test_default_impls_helper (lp_tensor , lp_tensor_for_copy )
173
+
174
+ l = torch .nn .Linear (2 , 3 )
175
+ lp_tensor = MyTensorWithOptionalData (
176
+ l .weight , torch .zeros_like (l .weight ), "attr"
177
+ )
178
+ l = torch .nn .Linear (2 , 3 )
179
+ lp_tensor_for_copy = MyTensorWithOptionalData (
180
+ l .weight , torch .zeros_like (l .weight ), "attr"
181
+ )
182
+ self ._test_default_impls_helper (lp_tensor , lp_tensor_for_copy )
128
183
129
184
130
185
if __name__ == "__main__" :
0 commit comments