-
Notifications
You must be signed in to change notification settings - Fork 698
[GraphOptimizer] Extend Reshape sinking pass for binary eltwise ops #5715
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -1143,37 +1143,122 @@ bool SinkReshapes::run(Function *F, const CompilationContext &cctx) { | |||||||||||||||||||||||||||||||||||||||||||||||||
auto *node = &N; | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
// Sink Reshape below eltwise nodes. | ||||||||||||||||||||||||||||||||||||||||||||||||||
if (!node->isDataParallel() || node->hasSideEffects()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||
continue; | ||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||
if (node->isDataParallel() && !node->hasSideEffects()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||
// Unary eltwise nodes. | ||||||||||||||||||||||||||||||||||||||||||||||||||
if (node->getNumInputs() == 1 && node->getNumResults() == 1) { | ||||||||||||||||||||||||||||||||||||||||||||||||||
DCHECK(node->getNthResult(0).getType()->isFusedQuantizedType() || | ||||||||||||||||||||||||||||||||||||||||||||||||||
node->getNthInput(0).getType()->isFusedQuantizedType() || | ||||||||||||||||||||||||||||||||||||||||||||||||||
node->getNthResult(0).dims().equals(node->getNthInput(0).dims())) | ||||||||||||||||||||||||||||||||||||||||||||||||||
<< "SinkReshapes: not an element-wise node: " << node->toString(); | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
auto *RS = dyn_cast<ReshapeNode>(node->getNthInput(0)); | ||||||||||||||||||||||||||||||||||||||||||||||||||
if (!RS) { | ||||||||||||||||||||||||||||||||||||||||||||||||||
continue; | ||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
// Unary eltwise nodes. | ||||||||||||||||||||||||||||||||||||||||||||||||||
if (node->getNumInputs() != 1 || node->getNumResults() != 1) { | ||||||||||||||||||||||||||||||||||||||||||||||||||
continue; | ||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||
// Create new eltwise node. | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto in = RS->getInput(); | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto out = node->getNthResult(0); | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto newTy = | ||||||||||||||||||||||||||||||||||||||||||||||||||
F->getParent()->uniqueTypeWithNewShape(out.getType(), in.dims()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto *newN = F->addNode(node->clone()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
newN->setNthInput(0, in); | ||||||||||||||||||||||||||||||||||||||||||||||||||
newN->setTypeUnsafe(0, newTy); | ||||||||||||||||||||||||||||||||||||||||||||||||||
newN->setPredicate(node->getPredicate()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
// Create new Reshape. | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto *newRS = F->createReshape(RS->getName(), newN, | ||||||||||||||||||||||||||||||||||||||||||||||||||
RS->getResult().getType()->dims()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
newRS->setPredicate(node->getPredicate()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
out.replaceAllUsesOfWith(newRS->getResult()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
auto *RS = dyn_cast<ReshapeNode>(node->getNthInput(0)); | ||||||||||||||||||||||||||||||||||||||||||||||||||
if (!RS) { | ||||||||||||||||||||||||||||||||||||||||||||||||||
continue; | ||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||
changed = true; | ||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
// Binary eltwise nodes. | ||||||||||||||||||||||||||||||||||||||||||||||||||
if (node->getNumInputs() == 2 && node->getNumResults() == 1) { | ||||||||||||||||||||||||||||||||||||||||||||||||||
DCHECK(node->getNthResult(0).getType()->isFusedQuantizedType() || | ||||||||||||||||||||||||||||||||||||||||||||||||||
node->getNthInput(0).getType()->isFusedQuantizedType() || | ||||||||||||||||||||||||||||||||||||||||||||||||||
node->getNthResult(0).dims().equals(node->getNthInput(0).dims())) | ||||||||||||||||||||||||||||||||||||||||||||||||||
<< "SinkReshapes: not an element-wise node: " << node->toString(); | ||||||||||||||||||||||||||||||||||||||||||||||||||
DCHECK(node->getNthResult(0).getType()->isFusedQuantizedType() || | ||||||||||||||||||||||||||||||||||||||||||||||||||
node->getNthInput(1).getType()->isFusedQuantizedType() || | ||||||||||||||||||||||||||||||||||||||||||||||||||
node->getNthResult(0).dims().equals(node->getNthInput(1).dims())) | ||||||||||||||||||||||||||||||||||||||||||||||||||
<< "SinkReshapes: not an element-wise node: " << node->toString(); | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
// At least one of the inputs must be a Reshape. | ||||||||||||||||||||||||||||||||||||||||||||||||||
// If both inputs are Reshapes, they must have the same dimensions. | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto *LRN = dyn_cast<ReshapeNode>(node->getNthInput(0)); | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto *RRN = dyn_cast<ReshapeNode>(node->getNthInput(1)); | ||||||||||||||||||||||||||||||||||||||||||||||||||
if (!LRN && !RRN) { | ||||||||||||||||||||||||||||||||||||||||||||||||||
continue; | ||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||
if (LRN && RRN && | ||||||||||||||||||||||||||||||||||||||||||||||||||
!LRN->getResult().dims().equals(RRN->getResult().dims())) { | ||||||||||||||||||||||||||||||||||||||||||||||||||
continue; | ||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
// Create new eltwise node. | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto in = RS->getInput(); | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto out = node->getNthResult(0); | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto newTy = | ||||||||||||||||||||||||||||||||||||||||||||||||||
F->getParent()->uniqueTypeWithNewShape(out.getType(), in.dims()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto *newN = F->addNode(node->clone()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
newN->setNthInput(0, in); | ||||||||||||||||||||||||||||||||||||||||||||||||||
newN->setTypeUnsafe(0, newTy); | ||||||||||||||||||||||||||||||||||||||||||||||||||
newN->setPredicate(node->getPredicate()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
// Canonicalize node to simplify transformation implementation (make LHS | ||||||||||||||||||||||||||||||||||||||||||||||||||
// always be the input with a Reshape). | ||||||||||||||||||||||||||||||||||||||||||||||||||
bool swap = (LRN == nullptr); | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto nv = node->getNthInput(1); | ||||||||||||||||||||||||||||||||||||||||||||||||||
if (swap) { | ||||||||||||||||||||||||||||||||||||||||||||||||||
nv = node->getNthInput(0); | ||||||||||||||||||||||||||||||||||||||||||||||||||
LRN = RRN; | ||||||||||||||||||||||||||||||||||||||||||||||||||
RRN = nullptr; | ||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
// Create new Reshape. | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto *newRS = F->createReshape(RS->getName(), newN, | ||||||||||||||||||||||||||||||||||||||||||||||||||
RS->getResult().getType()->dims()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
newRS->setPredicate(node->getPredicate()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
out.replaceAllUsesOfWith(newRS->getResult()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
// RHS must be either a Reshape or a Constant (+ Quantize) or Splat. | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto *RQ = dyn_cast<QuantizeNode>(nv); | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto *RC = dyn_cast<Constant>(RQ ? RQ->getInput() : nv); | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto *RS = dyn_cast<SplatNode>(nv); | ||||||||||||||||||||||||||||||||||||||||||||||||||
if (!RRN && !RC && !RS) { | ||||||||||||||||||||||||||||||||||||||||||||||||||
continue; | ||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+1211
to
+1217
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Combined with the below suggestion of not relying on specific nodes -- can we just check if this is a constant chain of operations and fold if so? I.e. we could make
Suggested change
(We'd need to do something like assume Interpreter is the constant folding backend, but I think that's generally the case anyway...) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we use |
||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
changed = true; | ||||||||||||||||||||||||||||||||||||||||||||||||||
// Create new Constant, Quantize or Splat, if needed. | ||||||||||||||||||||||||||||||||||||||||||||||||||
NodeValue rhs; | ||||||||||||||||||||||||||||||||||||||||||||||||||
if (RRN) { | ||||||||||||||||||||||||||||||||||||||||||||||||||
rhs = RRN->getInput(); | ||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||
if (RC) { | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto ty = F->getParent()->uniqueTypeWithNewShape( | ||||||||||||||||||||||||||||||||||||||||||||||||||
RC->getType(), LRN->getInput().dims()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto *newC = F->getParent()->createConstant(ty, RC->getName()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
newC->getPayloadMutable().copyRawFrom(&RC->getPayload()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
rhs = newC->getOutput(); | ||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||
if (RQ) { | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto ty = F->getParent()->uniqueTypeWithNewShape( | ||||||||||||||||||||||||||||||||||||||||||||||||||
RQ->getResult().getType(), LRN->getInput().dims()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
rhs = F->createQuantize(RQ->getName(), rhs, ty); | ||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||
if (RS) { | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto ty = F->getParent()->uniqueTypeWithNewShape( | ||||||||||||||||||||||||||||||||||||||||||||||||||
RS->getResult().getType(), LRN->getInput().dims()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
rhs = F->createSplat(RS->getName(), ty, RS->getValue()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+1219
to
+1240
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All this code seems unnecessarily complex. We have optimizations that do folding of Reshapes into Constants/Splats, and if we have a
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
// Create new eltwise node. | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto lhs = LRN->getInput(); | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto out = node->getNthResult(0); | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto newTy = | ||||||||||||||||||||||||||||||||||||||||||||||||||
F->getParent()->uniqueTypeWithNewShape(out.getType(), lhs.dims()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto *newN = F->addNode(node->clone()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
newN->setNthInput(0, swap ? rhs : lhs); | ||||||||||||||||||||||||||||||||||||||||||||||||||
newN->setNthInput(1, swap ? lhs : rhs); | ||||||||||||||||||||||||||||||||||||||||||||||||||
newN->setTypeUnsafe(0, newTy); | ||||||||||||||||||||||||||||||||||||||||||||||||||
newN->setPredicate(node->getPredicate()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
// Create new Reshape. | ||||||||||||||||||||||||||||||||||||||||||||||||||
auto *newRN = F->createReshape(LRN->getName(), newN, | ||||||||||||||||||||||||||||||||||||||||||||||||||
LRN->getResult().getType()->dims()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
newRN->setPredicate(node->getPredicate()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
out.replaceAllUsesOfWith(newRN->getResult()); | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
changed = true; | ||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||
return changed; | ||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -673,7 +673,6 @@ int main(int argc, char **argv) { | |
.addResultFromCtorArg() | ||
.dataParallel() | ||
.setDocstring("Performs element-wise exponential to the Input."); | ||
// clang-format on | ||
|
||
BB.newNode("Logit") | ||
.addInput("Input") | ||
|
@@ -685,8 +684,8 @@ int main(int argc, char **argv) { | |
BB.newNode("NonZero") | ||
.addInput("Cond") | ||
.addResultFromCtorArg() | ||
.dataParallel() | ||
.setDocstring("Selects indices of the true elements in Cond"); | ||
// clang-format on | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Assuming you're doing this because |
||
|
||
BB.newNode("Select") | ||
.addInput("Cond") | ||
|
@@ -1432,7 +1431,6 @@ int main(int argc, char **argv) { | |
.addInput("Input") | ||
.addInput("Mapping") | ||
.addResultFromCtorArg() | ||
.dataParallel() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this necessary? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
.setDocstring("Simple mapping between quantized numbers." | ||
"This can be used as quantized sigmoid or tanh functions."); | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why can't we leave this as an early continue? Looks like there's no logic after the
if
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The plan is to extend this pass further with other nodes (non data parallel). There is more reshape sinking code in Cadence we're planning to upstream eventually, plus I was planning to move existing reshape related code from
sinkCode
here.