Skip to content

[graph] Method to test if placeholder is an output #2160

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/glow/Graph/Nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ class Placeholder : public Storage {
/// differentiation.
bool isTraining() const { return isTrainable_; }

/// \returns True if this placeholder is an output of \p F.
bool isOutput(Function *F) const;
Copy link
Contributor

@rdzhabarov rdzhabarov Dec 12, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we'd similarly need isInput().
Probably !isOutput should be good, unless we could make isInput more efficient in implementation

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some nodes can be inputs and outputs. For example, quantization nodes. How do we handle them?

Copy link
Contributor

@rdzhabarov rdzhabarov Dec 12, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! Profiling nodes are kinda special (similar thing would happen with training placeholders, but we have a way to check isTraining).

  • one simple solution would be to have an enum for type of placeholder { in, out, inout } and return it here.
  • another idea is to force placeholders for profiling nodes to be trainable (since we change values and then re-read them, similar process to training).
  • something else?

First approach seems to be cleaner.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option could be to use isOverwrittenNthInput() -- we use it for checking for these other cases (profiling and training), right? So we could also add a check here for if a user is using the Placeholder as an overwritten input. Then we would also need a separate isInput().


static bool classof(const Kinded *k) {
return k->getKind() == Kinded::Kind::PlaceholderKind;
}
Expand Down
10 changes: 10 additions & 0 deletions lib/Graph/Nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,16 @@ llvm::hash_code Placeholder::getHash() const {
return llvm::hash_combine(getName());
}

bool Placeholder::isOutput(Function *F) const {
for (auto const &use : getUsers()) {
auto *user = use.getUser();
if (llvm::isa<SaveNode>(user) && (user->getParent() == F)) {
return true;
}
}
return false;
}

//===----------------------------------------------------------------------===//
// Visitor methods
//===----------------------------------------------------------------------===//
Expand Down
17 changes: 17 additions & 0 deletions tests/unittests/graphTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1202,3 +1202,20 @@ TEST(Graph, hookTest) {
ASSERT_TRUE(ph);
ASSERT_EQ(ph, in);
}

/// Check that output placeholders can be identified.
TEST(Graph, outputPlaceholderTest) {
Module mod;
auto *F = mod.createFunction("main");
auto *G = mod.createFunction("nope");
auto *input =
mod.createPlaceholder(ElemKind::FloatTy, {1, 4}, "input", false);
auto *weights = mod.createConstant(ElemKind::FloatTy, {4, 4}, "weights");
auto *bias = mod.createConstant(ElemKind::FloatTy, {4}, "bias");
auto *FC = F->createFullyConnected("fc", input, weights, bias);
auto *save = F->createSave("save", FC);
EXPECT_TRUE(save->getPlaceholder()->isOutput(F));
EXPECT_FALSE(input->isOutput(F));
EXPECT_FALSE(save->getPlaceholder()->isOutput(G));
EXPECT_FALSE(input->isOutput(G));
}