-
Notifications
You must be signed in to change notification settings - Fork 566
Add sym_sizes
with dynamic shape support
#3909
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -279,6 +279,29 @@ TEST_F(AtenXlaTensorTest, TestSubScalarInPlace) { | |
}); | ||
} | ||
|
||
TEST_F(AtenXlaTensorTest, TestSymSizes) { | ||
ForEachDevice([&](const torch::Device& device) { | ||
torch::Tensor a = torch::rand({2, 3}, torch::TensorOptions(torch::kFloat)); | ||
torch::Tensor xla_a = CopyToDevice(a, device); | ||
ASSERT_EQ(a.sym_sizes().at(0).expect_int(), 2); | ||
ASSERT_EQ(a.sym_sizes().at(0).is_symbolic(), false); | ||
|
||
torch::Tensor b = torch::tensor({{0.0, 1.0}, {0.0, 0.0}}, | ||
torch::TensorOptions(torch::kFloat)); | ||
torch::Tensor xla_b = CopyToDevice(b, device); | ||
xla_b = torch::nonzero(xla_b); | ||
auto s0 = xla_b.sym_sizes().at(0); | ||
ASSERT_EQ(s0.is_symbolic(), true); | ||
auto sininode = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: is the variable name meant to be |
||
dynamic_cast<XLASymIntNodeImpl*>(s0.toSymIntNodeImpl().get()); | ||
auto snode = | ||
std::dynamic_pointer_cast<torch_xla::SizeNode>(sininode->node()); | ||
ASSERT_TRUE(snode); | ||
ASSERT_EQ(snode->getStaticValue(), 4); | ||
ASSERT_EQ(snode->getDynamicValue(), 1); | ||
}); | ||
} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we add a dynamic test? You can call |
||
TEST_F(AtenXlaTensorTest, TestMul) { | ||
torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat)); | ||
torch::Tensor b = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat)); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,11 +6,15 @@ | |
|
||
#include "tensorflow/compiler/xla/xla_client/computation_client.h" | ||
#include "tensorflow/compiler/xla/xla_client/debug_macros.h" | ||
#include "torch/csrc/lazy/backend/backend_interface.h" | ||
#include "torch/csrc/lazy/core/tensor.h" | ||
#include "torch/csrc/lazy/core/tensor_util.h" | ||
#include "torch/csrc/lazy/core/util.h" | ||
#include "torch_xla/csrc/aten_xla_bridge.h" | ||
#include "torch_xla/csrc/device.h" | ||
#include "torch_xla/csrc/ir_builder.h" | ||
#include "torch_xla/csrc/layout_manager.h" | ||
#include "torch_xla/csrc/ops/dynamic_ir.h" | ||
#include "torch_xla/csrc/tensor_util.h" | ||
|
||
namespace torch_xla { | ||
|
@@ -115,9 +119,9 @@ at::IntArrayRef XLATensorImpl::sizes_custom() const { | |
} | ||
|
||
c10::SymIntArrayRef XLATensorImpl::sym_sizes_custom() const { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wondering what the "custom" in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks |
||
auto sizes = sizes_custom(); | ||
return c10::SymIntArrayRef(reinterpret_cast<const c10::SymInt*>(sizes.data()), | ||
sizes.size()); | ||
// N.B. SetupSizeProperties also updates sym_sizes_ | ||
const_cast<XLATensorImpl*>(this)->SetupSizeProperties(); | ||
return c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size()); | ||
} | ||
|
||
c10::SymInt XLATensorImpl::sym_numel_custom() const { | ||
|
@@ -168,10 +172,31 @@ void XLATensorImpl::SetupSizeProperties() { | |
for (int i = 0; i < updated_strides.size(); i++) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't we move this for loop inside There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if I understand your point, you are suggesting to duplicate the strides logic inside If this is your concern, I believe it's fine for now, since There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ezyang: based on our offline conversation about |
||
sizes_and_strides_.stride_at_unchecked(i) = updated_strides[i]; | ||
} | ||
SetupSymSizeProperties(); | ||
generation_ = generation; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ditto |
||
} | ||
} | ||
|
||
void XLATensorImpl::SetupSymSizeProperties() { | ||
auto shape = tensor_->shape(); | ||
auto rank = shape.get().rank(); | ||
std::vector<c10::SymInt> sym_sizes; | ||
sym_sizes.reserve(rank); | ||
|
||
XLAIrBuilder a = XLAIrBuilder(); | ||
for (auto i : c10::irange(rank)) { | ||
if (shape.get().is_dynamic_dimension(i)) { | ||
auto dim_node = a.MakeSizeNode(tensor_->GetIrValue(), i); | ||
auto symint_node = c10::make_intrusive<XLASymIntNodeImpl>(dim_node); | ||
auto sn = symint_node->toSymInt(); | ||
sym_sizes.push_back(sn); | ||
} else { | ||
sym_sizes.push_back(c10::SymInt(shape.get().dimensions(i))); | ||
} | ||
} | ||
sym_sizes_ = sym_sizes; | ||
} | ||
|
||
caffe2::TypeMeta XLATensorImpl::GetTypeMeta(const XLATensor& tensor) { | ||
return c10::scalarTypeToTypeMeta(tensor.dtype()); | ||
} | ||
|
Uh oh!
There was an error while loading. Please reload this page.