@@ -119,7 +119,8 @@ at::IntArrayRef XLATensorImpl::sizes_custom() const {
119
119
}
120
120
121
121
c10::SymIntArrayRef XLATensorImpl::sym_sizes_custom () const {
122
- const_cast <XLATensorImpl*>(this )->SetupSymSizeProperties ();
122
+ // N.B. SetupSizeProperties also updates sym_sizes_
123
+ const_cast <XLATensorImpl*>(this )->SetupSizeProperties ();
123
124
return c10::SymIntArrayRef (sym_sizes_.data (), sym_sizes_.size ());
124
125
}
125
126
@@ -171,37 +172,29 @@ void XLATensorImpl::SetupSizeProperties() {
171
172
for (int i = 0 ; i < updated_strides.size (); i++) {
172
173
sizes_and_strides_.stride_at_unchecked (i) = updated_strides[i];
173
174
}
175
+ SetupSymSizeProperties ();
174
176
generation_ = generation;
175
177
}
176
178
}
177
179
178
180
void XLATensorImpl::SetupSymSizeProperties () {
179
- size_t generation = tensor_->generation ();
180
- if (generation != generation_) {
181
- // Fill up the basic dimension data members which the base class
182
- // implementation uses in its APIs.
183
- auto shape = tensor_->shape ();
184
- auto rank = tensor_->shape ().get ().rank ();
185
- c10::SmallVector<c10::SymInt, 5 > sym_sizes;
186
- numel_ = 1 ;
187
- XLAIrBuilder a = XLAIrBuilder ();
188
- for (auto i : c10::irange (rank)) {
189
- if (tensor_->shape ().get ().is_dynamic_dimension (i)) {
190
- auto dim_node = a.MakeSizeNode (tensor_->GetIrValue (), i);
191
- auto symint_node =
192
- c10::make_intrusive<torch::lazy::SymIntNodeImpl>(dim_node);
193
- auto sn = symint_node->toSymInt ();
194
- sym_sizes_.push_back (sn);
195
- /* TODO(miladm): verify numel_ calculation after adding a dynamic op
196
- */
197
- numel_ *= dynamic_cast <SizeNode*>(dim_node.get ())->getStaticValue ();
198
- } else {
199
- sym_sizes_.push_back (c10::SymInt (tensor_->shape ().get ().dimensions (i)));
200
- numel_ *= tensor_->shape ().get ().dimensions (i);
201
- }
181
+ auto shape = tensor_->shape ();
182
+ auto rank = shape.get ().rank ();
183
+ std::vector<c10::SymInt> sym_sizes;
184
+ sym_sizes.reserve (rank);
185
+
186
+ XLAIrBuilder a = XLAIrBuilder ();
187
+ for (auto i : c10::irange (rank)) {
188
+ if (shape.get ().is_dynamic_dimension (i)) {
189
+ auto dim_node = a.MakeSizeNode (tensor_->GetIrValue (), i);
190
+ auto symint_node = c10::make_intrusive<XLASymIntNodeImpl>(dim_node);
191
+ auto sn = symint_node->toSymInt ();
192
+ sym_sizes.push_back (sn);
193
+ } else {
194
+ sym_sizes.push_back (c10::SymInt (shape.get ().dimensions (i)));
202
195
}
203
- generation_ = generation;
204
196
}
197
+ sym_sizes_ = sym_sizes;
205
198
}
206
199
207
200
caffe2::TypeMeta XLATensorImpl::GetTypeMeta (const XLATensor& tensor) {
0 commit comments