Skip to content

Commit e4befa3

Browse files
committed
addderss Jack's feedback
1 parent 1deb416 commit e4befa3

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,20 @@ TEST_F(AtenXlaTensorTest, TestSymSizes) {
285285
torch::Tensor xla_a = CopyToDevice(a, device);
286286
ASSERT_EQ(a.sym_sizes().at(0).expect_int(), 2);
287287
ASSERT_EQ(a.sym_sizes().at(0).is_symbolic(), false);
288+
289+
torch::Tensor b = torch::tensor({{0.0, 1.0}, {0.0, 0.0}},
290+
torch::TensorOptions(torch::kFloat));
291+
torch::Tensor xla_b = CopyToDevice(b, device);
292+
xla_b = torch::nonzero(xla_b);
293+
auto s0 = xla_b.sym_sizes().at(0);
294+
ASSERT_EQ(s0.is_symbolic(), true);
295+
auto sininode =
296+
dynamic_cast<XLASymIntNodeImpl*>(s0.toSymIntNodeImpl().get());
297+
auto snode =
298+
std::dynamic_pointer_cast<torch_xla::SizeNode>(sininode->node());
299+
ASSERT_TRUE(snode);
300+
ASSERT_EQ(snode->getStaticValue(), 4);
301+
ASSERT_EQ(snode->getDynamicValue(), 1);
288302
});
289303
}
290304

torch_xla/csrc/tensor_impl.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,8 @@
33
#include <ATen/Tensor.h>
44
#include <c10/core/Storage.h>
55
#include <c10/core/TensorImpl.h>
6-
#include <torch/csrc/lazy/backend/backend_interface.h>
7-
#include <torch/csrc/lazy/core/config.h>
8-
#include <torch/csrc/lazy/core/ir.h>
9-
#include <torch/csrc/lazy/core/trie.h>
106

117
#include "torch_xla/csrc/tensor.h"
12-
#include "torch_xla/csrc/xla_backend_impl.h"
138

149
namespace torch_xla {
1510

0 commit comments

Comments
 (0)