23 changes: 23 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -279,6 +279,29 @@ TEST_F(AtenXlaTensorTest, TestSubScalarInPlace) {
  });
}

TEST_F(AtenXlaTensorTest, TestSymSizes) {
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor a = torch::rand({2, 3}, torch::TensorOptions(torch::kFloat));
    torch::Tensor xla_a = CopyToDevice(a, device);
    ASSERT_EQ(a.sym_sizes().at(0).expect_int(), 2);
    ASSERT_EQ(a.sym_sizes().at(0).is_symbolic(), false);

    torch::Tensor b = torch::tensor({{0.0, 1.0}, {0.0, 0.0}},
                                    torch::TensorOptions(torch::kFloat));
    torch::Tensor xla_b = CopyToDevice(b, device);
    xla_b = torch::nonzero(xla_b);
    auto s0 = xla_b.sym_sizes().at(0);
    ASSERT_EQ(s0.is_symbolic(), true);
    auto sininode =
        dynamic_cast<XLASymIntNodeImpl*>(s0.toSymIntNodeImpl().get());
    auto snode =
        std::dynamic_pointer_cast<torch_xla::SizeNode>(sininode->node());
    ASSERT_TRUE(snode);
    ASSERT_EQ(snode->getStaticValue(), 4);
    ASSERT_EQ(snode->getDynamicValue(), 1);
  });
}

@miladm (Collaborator, Author), Sep 27, 2022, on the `auto sininode = ...` line:
nit: is the variable name meant to be sininode? @Krovatkin

Collaborator, on the new test:
Should we add a dynamic test? You can call nonzero, but you would need to make sure
https://github.com/pytorch/xla/blob/master/test/cpp/run_tests.sh#L11 is set so nonzero
won't fall back to CPU.
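As a rough illustration of the reviewer's suggestion, a guarded dynamic test could look like the sketch below. It assumes the flag exported in run_tests.sh is the XLA_EXPERIMENTAL environment variable and that skipping is acceptable when it is unset (otherwise nonzero falls back to CPU and the size stays static); the test name and the guard are illustrative, not part of this PR.

TEST_F(AtenXlaTensorTest, TestSymSizesDynamicGuarded) {
  // Assumed guard (requires <cstdlib>); the variable name mirrors what
  // run_tests.sh exports, which is an assumption of this sketch.
  if (std::getenv("XLA_EXPERIMENTAL") == nullptr) {
    GTEST_SKIP() << "dynamic nonzero is not enabled; it would fall back to CPU";
  }
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor b = torch::tensor({{0.0, 1.0}, {0.0, 0.0}},
                                    torch::TensorOptions(torch::kFloat));
    torch::Tensor xla_b = torch::nonzero(CopyToDevice(b, device));
    // nonzero's first dimension is data-dependent, so on the XLA backend it
    // should surface as a symbolic size.
    ASSERT_TRUE(xla_b.sym_sizes().at(0).is_symbolic());
  });
}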

TEST_F(AtenXlaTensorTest, TestMul) {
  torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
  torch::Tensor b = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
31 changes: 28 additions & 3 deletions torch_xla/csrc/tensor_impl.cpp
@@ -6,11 +6,15 @@

#include "tensorflow/compiler/xla/xla_client/computation_client.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "torch/csrc/lazy/backend/backend_interface.h"
#include "torch/csrc/lazy/core/tensor.h"
#include "torch/csrc/lazy/core/tensor_util.h"
#include "torch/csrc/lazy/core/util.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/ir_builder.h"
#include "torch_xla/csrc/layout_manager.h"
#include "torch_xla/csrc/ops/dynamic_ir.h"
#include "torch_xla/csrc/tensor_util.h"

namespace torch_xla {
@@ -115,9 +119,9 @@ at::IntArrayRef XLATensorImpl::sizes_custom() const {
}

c10::SymIntArrayRef XLATensorImpl::sym_sizes_custom() const {
-  auto sizes = sizes_custom();
-  return c10::SymIntArrayRef(reinterpret_cast<const c10::SymInt*>(sizes.data()),
-                             sizes.size());
+  // N.B. SetupSizeProperties also updates sym_sizes_
+  const_cast<XLATensorImpl*>(this)->SetupSizeProperties();
+  return c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size());
}

Collaborator, on sym_sizes_custom():
Wondering what the "custom" in sym_sizes_custom() implies.

Contributor:
"custom" means that a given TensorImpl can implement sym_sizes in whichever way it sees fit. Note that sym_sizes_custom is a virtual method whereas sym_sizes() isn't: sym_sizes() checks whether a given tensor impl uses the CustomSizes policy and, if so, calls the virtual sym_sizes_custom. This is somewhat convoluted, but it was done so we don't pay a virtual-call penalty on the fast path: CPU and CUDA tensors don't need a custom sym_sizes implementation.

Collaborator:
Thanks
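To make the fast-path point above concrete, here is an illustrative-only sketch of the dispatch being described; the struct and member names are assumptions for illustration (it is not the actual c10::TensorImpl source, and it assumes <c10/core/SymIntArrayRef.h>).

struct TensorImplSketch {
  bool has_custom_sizes_policy = false;  // an XLA/lazy impl would opt in

  // Non-virtual entry point: CPU/CUDA tensors never take the virtual hop.
  c10::SymIntArrayRef sym_sizes() const {
    if (has_custom_sizes_policy) {
      return sym_sizes_custom();  // virtual, backend-specific (e.g. XLA)
    }
    return sym_sizes_default();   // fast path over the static sizes
  }

  virtual c10::SymIntArrayRef sym_sizes_custom() const {
    return sym_sizes_default();
  }
  c10::SymIntArrayRef sym_sizes_default() const { return {}; }
  virtual ~TensorImplSketch() = default;
};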

c10::SymInt XLATensorImpl::sym_numel_custom() const {
@@ -168,10 +172,31 @@ void XLATensorImpl::SetupSizeProperties() {
    for (int i = 0; i < updated_strides.size(); i++) {
      sizes_and_strides_.stride_at_unchecked(i) = updated_strides[i];
    }
    SetupSymSizeProperties();
    generation_ = generation;
  }
}

@miladm (Collaborator, Author), Sep 27, 2022, on the strides loop:
Shouldn't we move this for loop inside SetupSymSizeProperties()? That would make the SetupSizeProperties and SetupSymSizeProperties implementations consistent. Is there a reason we want it outside SetupSymSizeProperties? @Krovatkin

@Krovatkin (Contributor), Sep 27, 2022:
If I understand your point, you are suggesting to duplicate the strides logic inside SetupSymSizeProperties as well, so that it updates both sym_sizes and sym_strides the same way SetupSizeProperties updates sizes and strides? If that is the concern, I believe it's fine for now, since sym_strides still returns static strides. We should figure out whether we need support for dynamic strides as well or whether it's okay for sym_strides to keep returning static strides.

@miladm (Collaborator, Author):
@ezyang: based on our offline conversation about sym_strides under functionalization, it seems we assume no need for dynamic strides; correct?

@miladm (Collaborator, Author), on `generation_ = generation;`:
Ditto.
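Purely to illustrate the duplication discussed above (and explicitly not part of this PR, since sym_strides keeps returning static strides), a helper mirroring the stride computation might look like the sketch below. It assumes torch::lazy::ComputeArrayStrides, which SetupSizeProperties already uses for updated_strides; the helper name is hypothetical.

// Hypothetical sketch only: strides stay static, so no SizeNode is created.
std::vector<c10::SymInt> ComputeStaticSymStrides(
    const std::vector<int64_t>& dimensions) {
  auto strides = torch::lazy::ComputeArrayStrides(dimensions);
  std::vector<c10::SymInt> sym_strides;
  sym_strides.reserve(strides.size());
  for (int64_t stride : strides) {
    sym_strides.push_back(c10::SymInt(stride));  // always a static value
  }
  return sym_strides;
}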

void XLATensorImpl::SetupSymSizeProperties() {
  auto shape = tensor_->shape();
  auto rank = shape.get().rank();
  std::vector<c10::SymInt> sym_sizes;
  sym_sizes.reserve(rank);

  XLAIrBuilder a = XLAIrBuilder();
  for (auto i : c10::irange(rank)) {
    if (shape.get().is_dynamic_dimension(i)) {
      auto dim_node = a.MakeSizeNode(tensor_->GetIrValue(), i);
      auto symint_node = c10::make_intrusive<XLASymIntNodeImpl>(dim_node);
      auto sn = symint_node->toSymInt();
      sym_sizes.push_back(sn);
    } else {
      sym_sizes.push_back(c10::SymInt(shape.get().dimensions(i)));
    }
  }
  sym_sizes_ = sym_sizes;
}

caffe2::TypeMeta XLATensorImpl::GetTypeMeta(const XLATensor& tensor) {
  return c10::scalarTypeToTypeMeta(tensor.dtype());
}
2 changes: 2 additions & 0 deletions torch_xla/csrc/tensor_impl.h
@@ -51,10 +51,12 @@ class XLATensorImpl : public c10::TensorImpl {

 private:
  void SetupSizeProperties();
  void SetupSymSizeProperties();

  static caffe2::TypeMeta GetTypeMeta(const XLATensor& tensor);

  XLATensorPtr tensor_;
  std::vector<c10::SymInt> sym_sizes_;
  size_t generation_ = 0;
};
