[GraphOptimizer] Extend Reshape sinking pass for binary eltwise ops #5715
Conversation
Force-pushed from 22139d9 to bbc3eab.
This issue has been automatically marked as stale because it has not had recent activity. It will be closed in 15 days if no further activity occurs. Thank you for your contributions.
Ping.

Ping.
Makes sense, but I added some comments about making it more general and less complicated.
```diff
@@ -1432,7 +1431,6 @@ int main(int argc, char **argv) {
       .addInput("Input")
       .addInput("Mapping")
       .addResultFromCtorArg()
-      .dataParallel()
```
Why is this necessary?
`IntLookupTable` is not really a data parallel operation. We need the whole `Mapping` input in order to produce a single output element.
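To illustrate the point, here is a toy sketch (hypothetical code, not Glow's implementation) of what a table lookup computes: every output element indexes into the *entire* mapping tensor, so the op cannot be sliced uniformly across all of its inputs the way a truly data-parallel op (e.g. an elementwise Add) can.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy model of an int lookup table: out[i] = mapping[input[i]].
// Although the loop is elementwise over `input`, producing even one
// output element requires the whole 256-entry `mapping` tensor, so the
// inputs cannot all be partitioned the same way.
std::vector<int8_t> intLookupTable(const std::vector<int8_t> &input,
                                   const std::vector<int8_t> &mapping) {
  std::vector<int8_t> out;
  out.reserve(input.size());
  for (int8_t v : input) {
    // Each output element reads from the full mapping table.
    out.push_back(mapping[static_cast<uint8_t>(v)]);
  }
  return out;
}
```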
```cpp
      .setDocstring("Selects indices of the true elements in Cond");
  // clang-format on
```
nit: Assuming you're doing this because `NonZero` is seeing weird formatting when `dataParallel` is removed -- can we keep the original `// clang-format on` after `Exp` and add `// clang-format off` just above `NonZero` here? Otherwise we're wrapping `Logit` unnecessarily too.
```diff
-    if (!node->isDataParallel() || node->hasSideEffects()) {
-      continue;
-    }
+    if (node->isDataParallel() && !node->hasSideEffects()) {
```
Why can't we leave this as an early continue? Looks like there's no logic after the `if`.
The plan is to extend this pass further with other (non-data-parallel) nodes. There is more reshape-sinking code in Cadence that we're planning to upstream eventually, plus I was planning to move the existing reshape-related code from `sinkCode` here.
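The structural trade-off being discussed can be sketched generically (a toy example, not Glow code): an early `continue` is tidy while the loop body handles exactly one category of nodes, whereas a guarded `if` per category leaves room to append further cases without restructuring the loop.

```cpp
#include <cassert>
#include <vector>

// Toy stand-in for a node; the fields are hypothetical, not Glow's API.
struct Node {
  bool dataParallel;
  bool sideEffects;
  bool otherHandledCase; // a future, non-data-parallel category
};

// Early-continue style: fine while the loop handles a single category,
// but everything after the continue is unreachable for other node kinds.
int countEarlyContinue(const std::vector<Node> &nodes) {
  int handled = 0;
  for (const auto &n : nodes) {
    if (!n.dataParallel || n.sideEffects) {
      continue;
    }
    ++handled;
  }
  return handled;
}

// Guarded-if style: each category gets its own block, so new cases
// (e.g. non-data-parallel reshape sinking) can simply be appended.
int countGuardedIf(const std::vector<Node> &nodes) {
  int handled = 0;
  for (const auto &n : nodes) {
    if (n.dataParallel && !n.sideEffects) {
      ++handled;
    }
    if (n.otherHandledCase) { // future extension slots in here
      ++handled;
    }
  }
  return handled;
}
```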
```cpp
    // Create new Constant, Quantize or Splat, if needed.
    NodeValue rhs;
    if (RRN) {
      rhs = RRN->getInput();
    }
    if (RC) {
      auto ty = F->getParent()->uniqueTypeWithNewShape(
          RC->getType(), LRN->getInput().dims());
      auto *newC = F->getParent()->createConstant(ty, RC->getName());
      newC->getPayloadMutable().copyRawFrom(&RC->getPayload());
      rhs = newC->getOutput();
    }
    if (RQ) {
      auto ty = F->getParent()->uniqueTypeWithNewShape(
          RQ->getResult().getType(), LRN->getInput().dims());
      rhs = F->createQuantize(RQ->getName(), rhs, ty);
    }
    if (RS) {
      auto ty = F->getParent()->uniqueTypeWithNewShape(
          RS->getResult().getType(), LRN->getInput().dims());
      rhs = F->createSplat(RS->getName(), ty, RS->getValue());
    }
```
All this code seems unnecessarily complex. We have optimizations that fold Reshapes into Constants/Splats, and if we have a `Constant -> Quantize -> Reshape` chain then it'll get constant folded too. Can we not just add a Reshape here and let other optimizations clean it up later? E.g.:
Suggested change:

```diff
-    // Create new Constant, Quantize or Splat, if needed.
-    NodeValue rhs;
-    if (RRN) {
-      rhs = RRN->getInput();
-    }
-    if (RC) {
-      auto ty = F->getParent()->uniqueTypeWithNewShape(
-          RC->getType(), LRN->getInput().dims());
-      auto *newC = F->getParent()->createConstant(ty, RC->getName());
-      newC->getPayloadMutable().copyRawFrom(&RC->getPayload());
-      rhs = newC->getOutput();
-    }
-    if (RQ) {
-      auto ty = F->getParent()->uniqueTypeWithNewShape(
-          RQ->getResult().getType(), LRN->getInput().dims());
-      rhs = F->createQuantize(RQ->getName(), rhs, ty);
-    }
-    if (RS) {
-      auto ty = F->getParent()->uniqueTypeWithNewShape(
-          RS->getResult().getType(), LRN->getInput().dims());
-      rhs = F->createSplat(RS->getName(), ty, RS->getValue());
-    }
+    // Reshape RHS to match LHS.
+    NodeValue rhs =
+        F->createReshape(nv.getNode()->getName(), nv, LRN->getInput().dims());
```
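The reviewer's rationale -- emit the simple form and let a later cleanup pass collapse it -- can be sketched with a toy IR (hypothetical types, not Glow's classes): the sinking pass always inserts a plain Reshape, and a separate folding pass collapses Reshape-of-Constant afterwards.

```cpp
#include <cassert>
#include <memory>

// Minimal toy IR node; not Glow's Node class.
struct Node {
  enum Kind { Constant, Reshape, Other } kind;
  std::shared_ptr<Node> input; // null for Constant/Other
};

std::shared_ptr<Node> makeConstant() {
  return std::make_shared<Node>(Node{Node::Constant, nullptr});
}

// The sinking pass: unconditionally wrap RHS in a Reshape, no special
// cases for Constant/Quantize/Splat inputs.
std::shared_ptr<Node> sinkReshape(std::shared_ptr<Node> rhs) {
  return std::make_shared<Node>(Node{Node::Reshape, std::move(rhs)});
}

// A later cleanup pass: Reshape(Constant) folds back into a Constant.
// (The reshaped payload is elided in this sketch.)
std::shared_ptr<Node> foldConstants(std::shared_ptr<Node> n) {
  if (n->kind == Node::Reshape && n->input->kind == Node::Constant) {
    return makeConstant();
  }
  return n;
}
```

A constant RHS thus ends up folded even though the sinking pass itself never special-cased it, while a non-constant RHS keeps its Reshape.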
```cpp
    // RHS must be either a Reshape or a Constant (+ Quantize) or Splat.
    auto *RQ = dyn_cast<QuantizeNode>(nv);
    auto *RC = dyn_cast<Constant>(RQ ? RQ->getInput() : nv);
    auto *RS = dyn_cast<SplatNode>(nv);
    if (!RRN && !RC && !RS) {
      continue;
    }
```
Combined with the suggestion below of not relying on specific nodes -- can we just check whether this is a constant chain of operations and fold if so? I.e. we could make `isConstantOperation` from `ConstantFolding.cpp` usable outside of that file and call it here, e.g.:
Suggested change:

```diff
-    // RHS must be either a Reshape or a Constant (+ Quantize) or Splat.
-    auto *RQ = dyn_cast<QuantizeNode>(nv);
-    auto *RC = dyn_cast<Constant>(RQ ? RQ->getInput() : nv);
-    auto *RS = dyn_cast<SplatNode>(nv);
-    if (!RRN && !RC && !RS) {
-      continue;
-    }
+    // RHS must be a constant chain if it's not a Reshape, so that the
+    // Reshape added below gets folded into the chain later on.
+    if (!RRN && !isConstantOperation(nv)) {
+      continue;
+    }
```
(We'd need to do something like assume Interpreter is the constant folding backend, but I think that's generally the case anyway...)
If we use `isConstantOperation`, we'll change the algorithm from O(N) to something like O(N^2). But the optimization should indeed benefit from it. Is that ok?
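The complexity concern can be illustrated with a toy count (not Glow code): if the per-node check has to walk the node's whole input chain, as a constant-chain test must, then over a chain of N nodes the pass performs roughly N(N+1)/2 node visits in total.

```cpp
#include <cassert>

// Toy model: node i in a straight-line chain has i - 1 predecessors.
// If checking node i means walking node i plus all of its predecessors
// (i visits), the whole pass does 1 + 2 + ... + n = n * (n + 1) / 2
// visits, i.e. quadratic in the chain length.
long visitsForChain(int n) {
  long total = 0;
  for (int i = 1; i <= n; ++i) {
    // Checking node i walks i nodes (itself plus its predecessors).
    for (int j = 0; j < i; ++j) {
      ++total;
    }
  }
  return total;
}
```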
@jfix71, can you comment on the following before I proceed with updating the changes?
This PR has been automatically closed due to being stale for 15 days. Thank you for your contributions and feel free to reopen it in case of further progress.
Summary:
Extend Reshape sinking pass for binary eltwise ops in GraphOptimizer.
Documentation:
N/A
Fixes #5247
Test Plan:
Added GraphOptz unit test.