diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index f6036287e471..787908ad3b71 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -279,6 +279,29 @@ TEST_F(AtenXlaTensorTest, TestSubScalarInPlace) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestSymSizes) {
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor a = torch::rand({2, 3}, torch::TensorOptions(torch::kFloat));
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    ASSERT_EQ(a.sym_sizes().at(0).expect_int(), 2);
+    ASSERT_EQ(a.sym_sizes().at(0).is_symbolic(), false);
+
+    torch::Tensor b = torch::tensor({{0.0, 1.0}, {0.0, 0.0}},
+                                    torch::TensorOptions(torch::kFloat));
+    torch::Tensor xla_b = CopyToDevice(b, device);
+    xla_b = torch::nonzero(xla_b);
+    auto s0 = xla_b.sym_sizes().at(0);
+    ASSERT_EQ(s0.is_symbolic(), true);
+    auto sininode =
+        dynamic_cast<XLASymIntNodeImpl*>(s0.toSymIntNodeImpl().get());
+    auto snode =
+        std::dynamic_pointer_cast<torch_xla::SizeNode>(sininode->node());
+    ASSERT_TRUE(snode);
+    ASSERT_EQ(snode->getStaticValue(), 4);
+    ASSERT_EQ(snode->getDynamicValue(), 1);
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestMul) {
   torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
   torch::Tensor b = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp
index 62cfb39b6c77..d21686f788de 100644
--- a/torch_xla/csrc/tensor_impl.cpp
+++ b/torch_xla/csrc/tensor_impl.cpp
@@ -6,11 +6,15 @@
 
 #include "tensorflow/compiler/xla/xla_client/computation_client.h"
 #include "tensorflow/compiler/xla/xla_client/debug_macros.h"
+#include "torch/csrc/lazy/backend/backend_interface.h"
+#include "torch/csrc/lazy/core/tensor.h"
 #include "torch/csrc/lazy/core/tensor_util.h"
 #include "torch/csrc/lazy/core/util.h"
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/device.h"
+#include "torch_xla/csrc/ir_builder.h"
 #include "torch_xla/csrc/layout_manager.h"
+#include "torch_xla/csrc/ops/dynamic_ir.h"
 #include "torch_xla/csrc/tensor_util.h"
 
 namespace torch_xla {
@@ -115,9 +119,9 @@ at::IntArrayRef XLATensorImpl::sizes_custom() const {
 }
 
 c10::SymIntArrayRef XLATensorImpl::sym_sizes_custom() const {
-  auto sizes = sizes_custom();
-  return c10::SymIntArrayRef(reinterpret_cast<const c10::SymInt*>(sizes.data()),
-                             sizes.size());
+  // N.B. SetupSizeProperties also updates sym_sizes_
+  const_cast<XLATensorImpl*>(this)->SetupSizeProperties();
+  return c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size());
 }
 
 c10::SymInt XLATensorImpl::sym_numel_custom() const {
@@ -168,10 +172,31 @@ void XLATensorImpl::SetupSizeProperties() {
     for (int i = 0; i < updated_strides.size(); i++) {
       sizes_and_strides_.stride_at_unchecked(i) = updated_strides[i];
     }
+    SetupSymSizeProperties();
     generation_ = generation;
   }
 }
 
+void XLATensorImpl::SetupSymSizeProperties() {
+  auto shape = tensor_->shape();
+  auto rank = shape.get().rank();
+  std::vector<c10::SymInt> sym_sizes;
+  sym_sizes.reserve(rank);
+
+  XLAIrBuilder a = XLAIrBuilder();
+  for (auto i : c10::irange(rank)) {
+    if (shape.get().is_dynamic_dimension(i)) {
+      auto dim_node = a.MakeSizeNode(tensor_->GetIrValue(), i);
+      auto symint_node = c10::make_intrusive<XLASymIntNodeImpl>(dim_node);
+      auto sn = symint_node->toSymInt();
+      sym_sizes.push_back(sn);
+    } else {
+      sym_sizes.push_back(c10::SymInt(shape.get().dimensions(i)));
+    }
+  }
+  sym_sizes_ = sym_sizes;
+}
+
 caffe2::TypeMeta XLATensorImpl::GetTypeMeta(const XLATensor& tensor) {
   return c10::scalarTypeToTypeMeta(tensor.dtype());
 }
diff --git a/torch_xla/csrc/tensor_impl.h b/torch_xla/csrc/tensor_impl.h
index 791ccb953637..2e76fe6e8289 100644
--- a/torch_xla/csrc/tensor_impl.h
+++ b/torch_xla/csrc/tensor_impl.h
@@ -51,10 +51,12 @@ class XLATensorImpl : public c10::TensorImpl {
 
  private:
   void SetupSizeProperties();
+  void SetupSymSizeProperties();
 
   static caffe2::TypeMeta GetTypeMeta(const XLATensor& tensor);
 
   XLATensorPtr tensor_;
+  std::vector<c10::SymInt> sym_sizes_;
   size_t generation_ = 0;
 };