From 9f9f57f362d616509bc392cef81544f301649c03 Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Tue, 11 Dec 2018 14:22:32 -0800 Subject: [PATCH 1/3] [graph] Method to test if placeholder is an output --- include/glow/Graph/Nodes.h | 3 +++ lib/Graph/Nodes.cpp | 9 +++++++++ tests/unittests/graphTest.cpp | 14 ++++++++++++++ 3 files changed, 26 insertions(+) diff --git a/include/glow/Graph/Nodes.h b/include/glow/Graph/Nodes.h index 3e5467a0be..0a2d530d37 100644 --- a/include/glow/Graph/Nodes.h +++ b/include/glow/Graph/Nodes.h @@ -123,6 +123,9 @@ class Placeholder : public Storage { /// differentiation. bool isTraining() const { return isTrainable_; } + /// \returns True if the placeholder is a Function output. + bool isOutput() const; + static bool classof(const Kinded *k) { return k->getKind() == Kinded::Kind::PlaceholderKind; } diff --git a/lib/Graph/Nodes.cpp b/lib/Graph/Nodes.cpp index f5c01d98e0..bd804d8a2e 100644 --- a/lib/Graph/Nodes.cpp +++ b/lib/Graph/Nodes.cpp @@ -35,6 +35,15 @@ llvm::hash_code Placeholder::getHash() const { return llvm::hash_combine(getName()); } +bool Placeholder::isOutput() const { + for (auto const &use : getUsers()) { + if (llvm::isa(use.getUser())) { + return true; + } + } + return false; +} + //===----------------------------------------------------------------------===// // Visitor methods //===----------------------------------------------------------------------===// diff --git a/tests/unittests/graphTest.cpp b/tests/unittests/graphTest.cpp index dca1dd49e2..f79e2d07d5 100644 --- a/tests/unittests/graphTest.cpp +++ b/tests/unittests/graphTest.cpp @@ -1202,3 +1202,17 @@ TEST(Graph, hookTest) { ASSERT_TRUE(ph); ASSERT_EQ(ph, in); } + +/// Check that output placeholders can be identified. +TEST(Graph, outputPlaceholderTest) { + Module mod; + auto *F = mod.createFunction("main"); + auto *input = + mod.createPlaceholder(ElemKind::FloatTy, {1, 4}, "input", false); + auto *weights = mod.createConstant(ElemKind::FloatTy, {4, 4}, "weights"); + auto *bias = mod.createConstant(ElemKind::FloatTy, {4}, "bias"); + auto *FC = F->createFullyConnected("fc", input, weights, bias); + auto *save = F->createSave("save", FC); + EXPECT_TRUE(save->getPlaceholder()->isOutput()); + EXPECT_FALSE(input->isOutput()); +} From 3dacd168ac8218847ef9291f8a1009f1c8af70ee Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Tue, 11 Dec 2018 16:30:42 -0800 Subject: [PATCH 2/3] Test against a specific function --- include/glow/Graph/Nodes.h | 4 ++-- lib/Graph/Nodes.cpp | 5 +++-- tests/unittests/graphTest.cpp | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/include/glow/Graph/Nodes.h b/include/glow/Graph/Nodes.h index 0a2d530d37..c128167f21 100644 --- a/include/glow/Graph/Nodes.h +++ b/include/glow/Graph/Nodes.h @@ -123,8 +123,8 @@ class Placeholder : public Storage { /// differentiation. bool isTraining() const { return isTrainable_; } - /// \returns True if the placeholder is a Function output. - bool isOutput() const; + /// \returns True if this placeholder is an output of \p F. + bool isOutput(Function *F) const; static bool classof(const Kinded *k) { return k->getKind() == Kinded::Kind::PlaceholderKind; diff --git a/lib/Graph/Nodes.cpp b/lib/Graph/Nodes.cpp index bd804d8a2e..1608a905ce 100644 --- a/lib/Graph/Nodes.cpp +++ b/lib/Graph/Nodes.cpp @@ -35,9 +35,10 @@ llvm::hash_code Placeholder::getHash() const { return llvm::hash_combine(getName()); } -bool Placeholder::isOutput() const { +bool Placeholder::isOutput(Function *F) const { for (auto const &use : getUsers()) { - if (llvm::isa(use.getUser())) { + auto *user = use.getUser(); + if (llvm::isa(user) && (user->getParent() == F)) { return true; } } diff --git a/tests/unittests/graphTest.cpp b/tests/unittests/graphTest.cpp index f79e2d07d5..50fb721c6b 100644 --- a/tests/unittests/graphTest.cpp +++ b/tests/unittests/graphTest.cpp @@ -1213,6 +1213,6 @@ TEST(Graph, outputPlaceholderTest) { auto *bias = mod.createConstant(ElemKind::FloatTy, {4}, "bias"); auto *FC = F->createFullyConnected("fc", input, weights, bias); auto *save = F->createSave("save", FC); - EXPECT_TRUE(save->getPlaceholder()->isOutput()); - EXPECT_FALSE(input->isOutput()); + EXPECT_TRUE(save->getPlaceholder()->isOutput(F)); + EXPECT_FALSE(input->isOutput(F)); } From fe2af3df62fc05390402643dbf5736568a1a6fac Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Tue, 11 Dec 2018 16:31:33 -0800 Subject: [PATCH 3/3] negative tests too --- tests/unittests/graphTest.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unittests/graphTest.cpp b/tests/unittests/graphTest.cpp index 50fb721c6b..29193ae37a 100644 --- a/tests/unittests/graphTest.cpp +++ b/tests/unittests/graphTest.cpp @@ -1207,6 +1207,7 @@ TEST(Graph, hookTest) { TEST(Graph, outputPlaceholderTest) { Module mod; auto *F = mod.createFunction("main"); + auto *G = mod.createFunction("nope"); auto *input = mod.createPlaceholder(ElemKind::FloatTy, {1, 4}, "input", false); auto *weights = mod.createConstant(ElemKind::FloatTy, {4, 4}, "weights"); @@ -1215,4 +1216,6 @@ TEST(Graph, outputPlaceholderTest) { auto *save = F->createSave("save", FC); EXPECT_TRUE(save->getPlaceholder()->isOutput(F)); EXPECT_FALSE(input->isOutput(F)); + EXPECT_FALSE(save->getPlaceholder()->isOutput(G)); + EXPECT_FALSE(input->isOutput(G)); }