Skip to content

[GraphOptimizer] Extend Reshape sinking pass for binary eltwise ops #5715

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

aturetsk
Copy link
Contributor

Summary:
Extend Reshape sinking pass for binary eltwise ops in GraphOptimizer.

Documentation:
N/A

Fixes #5247

Test Plan:
Added GraphOptz unit test.

@aturetsk
Copy link
Contributor Author

@jfix71, continuing effort on adding generic sinking logic started in #5616.

@aturetsk aturetsk force-pushed the cadence/turetski/reshape-sinking branch from 22139d9 to bbc3eab Compare June 10, 2021 19:53
@stale
Copy link

stale bot commented Jul 1, 2021

This issue has been automatically marked as stale because it has not had recent activity. It will be closed in 15 days if no further activity occurs. Thank you for your contributions.

@aturetsk
Copy link
Contributor Author

aturetsk commented Jul 1, 2021

Ping

@aturetsk
Copy link
Contributor Author

Ping.

Copy link
Contributor

@jfix71 jfix71 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, but added some comments about making it general, and less complicated.

@@ -1432,7 +1431,6 @@ int main(int argc, char **argv) {
.addInput("Input")
.addInput("Mapping")
.addResultFromCtorArg()
.dataParallel()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IntLookupTable is not really a data parallel operation. We need the whole Mapping input in order to produce a single output element.

.setDocstring("Selects indices of the true elements in Cond");
// clang-format on
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Assuming you're doing this because NonZero is seeing weird formatting when dataParallel is removed -- can we keep the original // clang-format on after Exp and add // clang-format off just above NonZero here? Otherwise we're wrapping Logit too unnecessarily..

if (!node->isDataParallel() || node->hasSideEffects()) {
continue;
}
if (node->isDataParallel() && !node->hasSideEffects()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't we leave this as an early continue? Looks like there's no logic after the if.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The plan is to extend this pass further with other nodes (non data parallel). There is more reshape sinking code in Cadence we're planning to upstream eventually, plus I was planning to move existing reshape related code from sinkCode here.

Comment on lines +1219 to +1240
// Create new Constant, Quantize or Splat, if needed.
NodeValue rhs;
if (RRN) {
rhs = RRN->getInput();
}
if (RC) {
auto ty = F->getParent()->uniqueTypeWithNewShape(
RC->getType(), LRN->getInput().dims());
auto *newC = F->getParent()->createConstant(ty, RC->getName());
newC->getPayloadMutable().copyRawFrom(&RC->getPayload());
rhs = newC->getOutput();
}
if (RQ) {
auto ty = F->getParent()->uniqueTypeWithNewShape(
RQ->getResult().getType(), LRN->getInput().dims());
rhs = F->createQuantize(RQ->getName(), rhs, ty);
}
if (RS) {
auto ty = F->getParent()->uniqueTypeWithNewShape(
RS->getResult().getType(), LRN->getInput().dims());
rhs = F->createSplat(RS->getName(), ty, RS->getValue());
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All this code seems unnecessarily complex. We have optimizations that do folding of Reshapes into Constants/Splats, and if we have a Constant -> Quantize -> Reshape then it'll get constant folded too. Can we not just add a Reshape here and let other optimizations clean it up later? E.g.

Suggested change
// Create new Constant, Quantize or Splat, if needed.
NodeValue rhs;
if (RRN) {
rhs = RRN->getInput();
}
if (RC) {
auto ty = F->getParent()->uniqueTypeWithNewShape(
RC->getType(), LRN->getInput().dims());
auto *newC = F->getParent()->createConstant(ty, RC->getName());
newC->getPayloadMutable().copyRawFrom(&RC->getPayload());
rhs = newC->getOutput();
}
if (RQ) {
auto ty = F->getParent()->uniqueTypeWithNewShape(
RQ->getResult().getType(), LRN->getInput().dims());
rhs = F->createQuantize(RQ->getName(), rhs, ty);
}
if (RS) {
auto ty = F->getParent()->uniqueTypeWithNewShape(
RS->getResult().getType(), LRN->getInput().dims());
rhs = F->createSplat(RS->getName(), ty, RS->getValue());
}
// Reshape RHS to match LHS.
NodeValue rhs = F->createSplat(nv.getNode()->getName(), nv, LRN->getInput().dims());

Comment on lines +1211 to +1217
// RHS must be either a Reshape or a Constant (+ Quantize) or Splat.
auto *RQ = dyn_cast<QuantizeNode>(nv);
auto *RC = dyn_cast<Constant>(RQ ? RQ->getInput() : nv);
auto *RS = dyn_cast<SplatNode>(nv);
if (!RRN && !RC && !RS) {
continue;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Combined with the below suggestion of not relying on specific nodes -- can we just check if this is a constant chain of operations and fold if so? I.e. we could make isConstantOperation from ConstantFolding.cpp usable outside of there and call it here, e.g.

Suggested change
// RHS must be either a Reshape or a Constant (+ Quantize) or Splat.
auto *RQ = dyn_cast<QuantizeNode>(nv);
auto *RC = dyn_cast<Constant>(RQ ? RQ->getInput() : nv);
auto *RS = dyn_cast<SplatNode>(nv);
if (!RRN && !RC && !RS) {
continue;
}
// RHS must be a constant chain if it's not a reshape, to allow for the reshape to get folded into the chain later on.
if (!RRN || !isConstantOperation(nv)) {
continue;
}

(We'd need to do something like assume Interpreter is the constant folding backend, but I think that's generally the case anyway...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we use isConstantOperation, we'll change the algorithm from O(N) to something like O(N^2). But the optimization should indeed benefit from it. Is that ok?

@stale
Copy link

stale bot commented Aug 21, 2021

This issue has been automatically marked as stale because it has not had recent activity. It will be closed in 15 days if no further activity occurs. Thank you for your contributions.

@aturetsk
Copy link
Contributor Author

aturetsk commented Aug 23, 2021

@jfix71, can you comment on the following before I proceed with updating the changes?

If we use isConstantOperation, we'll change the algorithm from O(N) to something like O(N^2). But the optimization should indeed benefit from it. Is that ok?

@stale
Copy link

stale bot commented Mar 2, 2022

This issue has been automatically marked as stale because it has not had recent activity. It will be closed in 15 days if no further activity occurs. Thank you for your contributions.

@stale
Copy link

stale bot commented Apr 16, 2022

This PR has been automatically closed due to being stale for 15 days. Thank you for your contributions and feel free to reopen it in case of further progress.

@stale stale bot closed this Apr 16, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[GraphOptimizer] Reshape sinking
3 participants