[Placeholder] Allow the differentiation of Placeholder nodes. #1612
Conversation
Force-pushed from 7a97d67 to 32e009e.
include/glow/Graph/Nodes.h
Outdated
@@ -31,8 +31,12 @@ namespace glow {
// Storage is the base class for Variables, which are bound to tensors, and
// Placeholder nodes which are unbound.
class Storage : public Node {
/// Specifies if the variable is trainable.
include/glow/Graph/Nodes.h
Outdated
@@ -50,6 +54,9 @@ class Storage : public Node {
Node *clone() const;
/// @}

/// \returns True if the Variable is initialized to be in training mode.
bool isTraining() const { return isTrainable_; }
@@ -144,8 +144,8 @@ class ChildMemSizeBasedScheduler : public Scheduler {
// We don't model memory dependencies, but we still need to honor them.
// Make sure the SaveNode happens after the last use of the output variable.
if (auto *save = dyn_cast<SaveNode>(N)) {
Variable *output = save->getVariable();
auto *input1 = mod.createPlaceholder(ty, "input1");
auto *input2 = mod.createPlaceholder(ty, "input2");
auto *input3 = mod.createPlaceholder(ty, "input3");
auto *input1 = mod.createPlaceholder(ty, "input1", false);
// Check that the Placeholder has users, because at least one write
// node will be added.
EXPECT_GE(A->getNumUsers(), 1);
lgtm!
Add the isTrain flag to placeholders. As we are working to get rid of mutable variables, this flag will allow us to convert all of our tests from Variable to Placeholder without rewriting all of our training tests.
Allow differentiation of Placeholder nodes. This commit fixes a few bugs in the scheduler, cloner, and differentiation function, where we assumed that the inputs to the graph are variables rather than placeholders. It also adds a unit test that checks that we can differentiate (and later compile) functions that differentiate (and update) placeholders.
Force-pushed from 32e009e to 48adbe7.
LGTM for the PR. But now we need to go over all the places where getVars() is used and decide whether we should also consider placeholders there. E.g. graph dump, Quantization, and GraphOptimizer.cpp definitely need to consider them.
@artemrakhov Yes! That's a good point. Thank you.
Description: Allow the differentiation of Placeholder nodes. This PR includes two main parts. First, we add a trainable flag to Placeholder to match the Variable node and allow differentiation of the selected nodes. Second, we fix a few places (scheduler, cloner, etc.) that assume that all storage nodes are Variables.
Testing: Added unit tests and ran ninja check.
Documentation: None.
This is related to #1334. This PR is a step in the direction of eliminating the Variables, and is required for porting our unit tests to Placeholders instead of training mutable variables.