Skip to content

Commit 06e9ecf

Browse files
committed
fix test failures
1 parent 7b6e44a commit 06e9ecf

File tree

1 file changed

+18
-25
lines changed

1 file changed

+18
-25
lines changed

torch_xla/csrc/tensor_impl.cpp

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ at::IntArrayRef XLATensorImpl::sizes_custom() const {
119119
}
120120

121121
c10::SymIntArrayRef XLATensorImpl::sym_sizes_custom() const {
122-
const_cast<XLATensorImpl*>(this)->SetupSymSizeProperties();
122+
// N.B. SetupSizeProperties also updates sym_sizes_
123+
const_cast<XLATensorImpl*>(this)->SetupSizeProperties();
123124
return c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size());
124125
}
125126

@@ -171,37 +172,29 @@ void XLATensorImpl::SetupSizeProperties() {
171172
for (int i = 0; i < updated_strides.size(); i++) {
172173
sizes_and_strides_.stride_at_unchecked(i) = updated_strides[i];
173174
}
175+
SetupSymSizeProperties();
174176
generation_ = generation;
175177
}
176178
}
177179

178180
void XLATensorImpl::SetupSymSizeProperties() {
179-
size_t generation = tensor_->generation();
180-
if (generation != generation_) {
181-
// Fill up the basic dimension data members which the base class
182-
// implementation uses in its APIs.
183-
auto shape = tensor_->shape();
184-
auto rank = tensor_->shape().get().rank();
185-
c10::SmallVector<c10::SymInt, 5> sym_sizes;
186-
numel_ = 1;
187-
XLAIrBuilder a = XLAIrBuilder();
188-
for (auto i : c10::irange(rank)) {
189-
if (tensor_->shape().get().is_dynamic_dimension(i)) {
190-
auto dim_node = a.MakeSizeNode(tensor_->GetIrValue(), i);
191-
auto symint_node =
192-
c10::make_intrusive<torch::lazy::SymIntNodeImpl>(dim_node);
193-
auto sn = symint_node->toSymInt();
194-
sym_sizes_.push_back(sn);
195-
/*TODO(miladm): verify numel_ calculation after adding a dynamic op
196-
*/
197-
numel_ *= dynamic_cast<SizeNode*>(dim_node.get())->getStaticValue();
198-
} else {
199-
sym_sizes_.push_back(c10::SymInt(tensor_->shape().get().dimensions(i)));
200-
numel_ *= tensor_->shape().get().dimensions(i);
201-
}
181+
auto shape = tensor_->shape();
182+
auto rank = shape.get().rank();
183+
std::vector<c10::SymInt> sym_sizes;
184+
sym_sizes.reserve(rank);
185+
186+
XLAIrBuilder a = XLAIrBuilder();
187+
for (auto i : c10::irange(rank)) {
188+
if (shape.get().is_dynamic_dimension(i)) {
189+
auto dim_node = a.MakeSizeNode(tensor_->GetIrValue(), i);
190+
auto symint_node = c10::make_intrusive<XLASymIntNodeImpl>(dim_node);
191+
auto sn = symint_node->toSymInt();
192+
sym_sizes.push_back(sn);
193+
} else {
194+
sym_sizes.push_back(c10::SymInt(shape.get().dimensions(i)));
202195
}
203-
generation_ = generation;
204196
}
197+
sym_sizes_ = sym_sizes;
205198
}
206199

207200
caffe2::TypeMeta XLATensorImpl::GetTypeMeta(const XLATensor& tensor) {

0 commit comments

Comments
 (0)