[GraphOptimizer] Extend Reshape sinking pass for binary eltwise ops #5715

Closed
2 changes: 2 additions & 0 deletions lib/Graph/TensorLayout.cpp
@@ -699,6 +699,8 @@ static bool acceptsAnyInputLayout(const glow::Node *node) {
case Kinded::Kind::BatchNormalizationNodeKind:
case Kinded::Kind::BatchNormalizationGradNodeKind:
case Kinded::Kind::PadNodeKind:
case Kinded::Kind::NonZeroNodeKind:
case Kinded::Kind::IntLookupTableNodeKind:
case Kinded::Kind::ReshapeNodeKind:
case Kinded::Kind::MeanVarNormalizationNodeKind:
case Kinded::Kind::MatMulNodeKind:
137 changes: 111 additions & 26 deletions lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp
@@ -1143,37 +1143,122 @@ bool SinkReshapes::run(Function *F, const CompilationContext &cctx) {
auto *node = &N;

// Sink Reshape below eltwise nodes.
if (!node->isDataParallel() || node->hasSideEffects()) {
continue;
}
if (node->isDataParallel() && !node->hasSideEffects()) {
Contributor

Why can't we leave this as an early continue? Looks like there's no logic after the if.

Contributor Author

The plan is to extend this pass further with other (non-data-parallel) nodes. There is more Reshape-sinking code in Cadence that we're planning to upstream eventually, plus I was planning to move the existing Reshape-related code from sinkCode here.
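
For context, a minimal sketch of what the unary rewrite does to a concrete graph, using the same graph-builder calls as the unit tests below (the helper name is hypothetical and not part of this PR):

// Sketch only: builds the "before" graph. After SinkReshapes runs, the Relu
// operates on the original {10, 10} shape and a new Reshape is created on
// its output instead.
static void buildUnarySinkExample(glow::Module &mod, glow::Function *F) {
  auto *in =
      mod.createPlaceholder(glow::ElemKind::FloatTy, {10, 10}, "in", false);
  auto *RN = F->createReshape("reshape", in, {5, 5, 4}); // Reshape feeds the eltwise op.
  auto *RL = F->createRELU("relu", RN);                  // Unary eltwise consumer.
  F->createSave("ret", RL);
  // After the pass the graph is equivalent to:
  //   ret = Save(Reshape(Relu(in), {5, 5, 4}))
}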

// Unary eltwise nodes.
if (node->getNumInputs() == 1 && node->getNumResults() == 1) {
DCHECK(node->getNthResult(0).getType()->isFusedQuantizedType() ||
node->getNthInput(0).getType()->isFusedQuantizedType() ||
node->getNthResult(0).dims().equals(node->getNthInput(0).dims()))
<< "SinkReshapes: not an element-wise node: " << node->toString();

auto *RS = dyn_cast<ReshapeNode>(node->getNthInput(0));
if (!RS) {
continue;
}

// Unary eltwise nodes.
if (node->getNumInputs() != 1 || node->getNumResults() != 1) {
continue;
}
// Create new eltwise node.
auto in = RS->getInput();
auto out = node->getNthResult(0);
auto newTy =
F->getParent()->uniqueTypeWithNewShape(out.getType(), in.dims());
auto *newN = F->addNode(node->clone());
newN->setNthInput(0, in);
newN->setTypeUnsafe(0, newTy);
newN->setPredicate(node->getPredicate());

// Create new Reshape.
auto *newRS = F->createReshape(RS->getName(), newN,
RS->getResult().getType()->dims());
newRS->setPredicate(node->getPredicate());
out.replaceAllUsesOfWith(newRS->getResult());

auto *RS = dyn_cast<ReshapeNode>(node->getNthInput(0));
if (!RS) {
continue;
}
changed = true;
}

// Binary eltwise nodes.
if (node->getNumInputs() == 2 && node->getNumResults() == 1) {
DCHECK(node->getNthResult(0).getType()->isFusedQuantizedType() ||
node->getNthInput(0).getType()->isFusedQuantizedType() ||
node->getNthResult(0).dims().equals(node->getNthInput(0).dims()))
<< "SinkReshapes: not an element-wise node: " << node->toString();
DCHECK(node->getNthResult(0).getType()->isFusedQuantizedType() ||
node->getNthInput(1).getType()->isFusedQuantizedType() ||
node->getNthResult(0).dims().equals(node->getNthInput(1).dims()))
<< "SinkReshapes: not an element-wise node: " << node->toString();

// At least one of the inputs must be a Reshape.
// If both inputs are Reshapes, they must have the same dimensions.
auto *LRN = dyn_cast<ReshapeNode>(node->getNthInput(0));
auto *RRN = dyn_cast<ReshapeNode>(node->getNthInput(1));
if (!LRN && !RRN) {
continue;
}
if (LRN && RRN &&
!LRN->getResult().dims().equals(RRN->getResult().dims())) {
continue;
}

// Create new eltwise node.
auto in = RS->getInput();
auto out = node->getNthResult(0);
auto newTy =
F->getParent()->uniqueTypeWithNewShape(out.getType(), in.dims());
auto *newN = F->addNode(node->clone());
newN->setNthInput(0, in);
newN->setTypeUnsafe(0, newTy);
newN->setPredicate(node->getPredicate());
// Canonicalize node to simplify transformation implementation (make LHS
// always be the input with a Reshape).
bool swap = (LRN == nullptr);
auto nv = node->getNthInput(1);
if (swap) {
nv = node->getNthInput(0);
LRN = RRN;
RRN = nullptr;
}

// Create new Reshape.
auto *newRS = F->createReshape(RS->getName(), newN,
RS->getResult().getType()->dims());
newRS->setPredicate(node->getPredicate());
out.replaceAllUsesOfWith(newRS->getResult());
// RHS must be either a Reshape or a Constant (+ Quantize) or Splat.
auto *RQ = dyn_cast<QuantizeNode>(nv);
auto *RC = dyn_cast<Constant>(RQ ? RQ->getInput() : nv);
auto *RS = dyn_cast<SplatNode>(nv);
if (!RRN && !RC && !RS) {
continue;
}
Comment on lines +1211 to +1217
Contributor

Combined with the below suggestion of not relying on specific nodes -- can we just check if this is a constant chain of operations and fold if so? I.e. we could make isConstantOperation from ConstantFolding.cpp usable outside of there and call it here, e.g.

Suggested change
// RHS must be either a Reshape or a Constant (+ Quantize) or Splat.
auto *RQ = dyn_cast<QuantizeNode>(nv);
auto *RC = dyn_cast<Constant>(RQ ? RQ->getInput() : nv);
auto *RS = dyn_cast<SplatNode>(nv);
if (!RRN && !RC && !RS) {
continue;
}
// RHS must be a constant chain if it's not a reshape, to allow for the reshape to get folded into the chain later on.
if (!RRN || !isConstantOperation(nv)) {
continue;
}

(We'd need to do something like assume Interpreter is the constant folding backend, but I think that's generally the case anyway...)

Contributor Author

If we use isConstantOperation, we'll change the algorithm from O(N) to something like O(N^2). But the optimization should indeed benefit from it. Is that ok?
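
For context, a minimal sketch of the binary case this block handles when the RHS is a Constant, using the same graph-builder calls as the unit test below (the helper name is hypothetical and not part of this PR):

// Sketch only: builds the "before" graph. After SinkReshapes runs, the Add
// operates on the original {10, 10} shape, the Constant is rebuilt with the
// {10, 10} shape (same payload), and a new Reshape is created on the Add's
// output instead.
static void buildBinarySinkExample(glow::Module &mod, glow::Function *F) {
  auto *in =
      mod.createPlaceholder(glow::ElemKind::FloatTy, {10, 10}, "in", false);
  auto *C = mod.createConstant(glow::ElemKind::FloatTy, {5, 5, 4}, "const");
  auto *RN = F->createReshape("reshape", in, {5, 5, 4}); // LHS comes from a Reshape.
  auto *AN = F->createAdd("add", RN, C);                 // Binary eltwise node.
  F->createSave("ret", AN);
  // After the pass the graph is equivalent to:
  //   ret = Save(Reshape(Add(in, C'), {5, 5, 4}))  where C' is C with shape {10, 10}.
}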


changed = true;
// Create new Constant, Quantize or Splat, if needed.
NodeValue rhs;
if (RRN) {
rhs = RRN->getInput();
}
if (RC) {
auto ty = F->getParent()->uniqueTypeWithNewShape(
RC->getType(), LRN->getInput().dims());
auto *newC = F->getParent()->createConstant(ty, RC->getName());
newC->getPayloadMutable().copyRawFrom(&RC->getPayload());
rhs = newC->getOutput();
}
if (RQ) {
auto ty = F->getParent()->uniqueTypeWithNewShape(
RQ->getResult().getType(), LRN->getInput().dims());
rhs = F->createQuantize(RQ->getName(), rhs, ty);
}
if (RS) {
auto ty = F->getParent()->uniqueTypeWithNewShape(
RS->getResult().getType(), LRN->getInput().dims());
rhs = F->createSplat(RS->getName(), ty, RS->getValue());
}
Comment on lines +1219 to +1240
Contributor

All this code seems unnecessarily complex. We have optimizations that do folding of Reshapes into Constants/Splats, and if we have a Constant -> Quantize -> Reshape then it'll get constant folded too. Can we not just add a Reshape here and let other optimizations clean it up later? E.g.

Suggested change
// Create new Constant, Quantize or Splat, if needed.
NodeValue rhs;
if (RRN) {
rhs = RRN->getInput();
}
if (RC) {
auto ty = F->getParent()->uniqueTypeWithNewShape(
RC->getType(), LRN->getInput().dims());
auto *newC = F->getParent()->createConstant(ty, RC->getName());
newC->getPayloadMutable().copyRawFrom(&RC->getPayload());
rhs = newC->getOutput();
}
if (RQ) {
auto ty = F->getParent()->uniqueTypeWithNewShape(
RQ->getResult().getType(), LRN->getInput().dims());
rhs = F->createQuantize(RQ->getName(), rhs, ty);
}
if (RS) {
auto ty = F->getParent()->uniqueTypeWithNewShape(
RS->getResult().getType(), LRN->getInput().dims());
rhs = F->createSplat(RS->getName(), ty, RS->getValue());
}
// Reshape RHS to match LHS.
NodeValue rhs = F->createReshape(nv.getNode()->getName(), nv, LRN->getInput().dims());


// Create new eltwise node.
auto lhs = LRN->getInput();
auto out = node->getNthResult(0);
auto newTy =
F->getParent()->uniqueTypeWithNewShape(out.getType(), lhs.dims());
auto *newN = F->addNode(node->clone());
newN->setNthInput(0, swap ? rhs : lhs);
newN->setNthInput(1, swap ? lhs : rhs);
newN->setTypeUnsafe(0, newTy);
newN->setPredicate(node->getPredicate());

// Create new Reshape.
auto *newRN = F->createReshape(LRN->getName(), newN,
LRN->getResult().getType()->dims());
newRN->setPredicate(node->getPredicate());
out.replaceAllUsesOfWith(newRN->getResult());

changed = true;
}
}
}
return changed;
}
82 changes: 70 additions & 12 deletions tests/unittests/GraphOptzTest.cpp
@@ -2323,21 +2323,23 @@ TEST_F(GraphOptz, ReshapeAfterSplat) {
const dim_t reshape[] = {1, 6000};
Type t1(ElemKind::FloatTy, shape);
Type t2(ElemKind::FloatTy, reshape);
Node *input = F_->getParent()->createPlaceholder(ElemKind::FloatTy, shape,
"input", true);
Node *input1 = F_->getParent()->createPlaceholder(ElemKind::FloatTy, shape,
"input1", true);
Node *input2 = F_->getParent()->createPlaceholder(ElemKind::FloatTy, reshape,
"input2", true);
auto *Z1 = F_->createSplat("zero1", &t1, 1.5);
auto *A1 = F_->createAdd("add1", Z1->getResult().getType(), input, Z1);
auto *A1 = F_->createAdd("add1", Z1->getResult().getType(), input1, Z1);
auto *R1 = F_->createReshape("reshape1", Z1, reshape);
// Z1 is used by R1 and A1.
// The reshape optimization will thus NOT be able to remove this reshape node
// (R1).
auto *R2 = F_->createReshape("reshape2", A1, reshape);
auto *A2 = F_->createAdd("add", R1->getResult().getType(), R1, R2);
F_->createSave("save", A1);
auto *A2 = F_->createAdd("add", R1->getResult().getType(), R1, input2);
auto *Z2 = F_->createSplat("zero2", &t1, 2.5);
auto *R3 = F_->createReshape("reshape3", Z2, reshape);
// Z2 is only used by R3.
// The Z2,R3 nodes will be replaced by a new splat node with the shape of R3.
auto *A3 = F_->createAdd("add", A2->getResult().getType(), A2, R3);
auto *R2 = F_->createReshape("reshape3", Z2, reshape);
// Z2 is only used by R2.
// The Z2,R2 nodes will be replaced by a new splat node with the shape of R2.
auto *A3 = F_->createAdd("add", A2->getResult().getType(), A2, R2);
auto *O = F_->createSave("ret", A3);

// Before optimization, we have 9 nodes in the graph.
@@ -2352,7 +2354,7 @@
// replace by a new splat node.
EXPECT_EQ(F_->getNodes().size(), 8);

// The second input of A3 should be a splat node with a shape of R3.
// The second input of A3 should be a splat node with a shape of R2.
auto *newA3 = llvm::dyn_cast<AddNode>(O->getInput());
ASSERT_TRUE(newA3);
auto *SN = llvm::dyn_cast<SplatNode>(newA3->getRHS());
Expand All @@ -2362,8 +2364,8 @@ TEST_F(GraphOptz, ReshapeAfterSplat) {
// R1 should still be in the graph.
EXPECT_TRUE(functionContainsNode(F_, R1));

// R3 and Z2 should not be in the graph any more.
EXPECT_FALSE(functionContainsNode(F_, R3));
// R2 and Z2 should not be in the graph any more.
EXPECT_FALSE(functionContainsNode(F_, R2));
EXPECT_FALSE(functionContainsNode(F_, Z2));
}

@@ -8043,6 +8045,62 @@ TEST_F(GraphOptz, SinkReshapeBelowUnaryEltwiseOps) {
checkNumericalEquivalence(0.f);
}

TEST_F(GraphOptz, SinkReshapeBelowBinaryEltwiseOps) {
const dim_t dimsIn[] = {10, 10};
const dim_t dimsOut[] = {5, 5, 4};

// Prepare inputs.
auto *in1 = mod_.createPlaceholder(glow::ElemKind::Int8QTy, dimsIn, 0.12f, 0,
"in1", false);
auto *in2 = mod_.createPlaceholder(glow::ElemKind::Int8QTy, dimsIn, 0.17f, 0,
"in2", false);
auto *QCN =
mod_.createConstant(ElemKind::Int8QTy, dimsOut, 0.13f, 0, "quant_const");
auto *FCN = mod_.createConstant(ElemKind::FloatTy, dimsOut, "float_const");
auto qTy = mod_.uniqueType(ElemKind::Int8QTy, dimsOut, 0.15f, 0);
auto *QN = F_->createQuantize("quantize", FCN, qTy);
auto *SN = F_->createSplat("splat", qTy, 1.79f);
QCN->getHandle<int8_t>().randomize(-128, 127, mod_.getPRNG());
FCN->getHandle<float>().randomize(-1.f, 2.f, mod_.getPRNG());

// Test different combinations of Reshape, Constant, Quantize, Splat passed as
// LHS or RHS.
auto *RN1 = F_->createReshape("reshape", in1, dimsOut);
auto *RN2 = F_->createReshape("reshape", in2, dimsOut);
auto *AN = F_->createAdd("add", RN1, RN2);
auto *MLN = F_->createMul("mul", AN, QCN);
auto *MXN = F_->createMax("max", QN, MLN);
auto *SBN = F_->createSub("sub", MXN, SN);
auto *save = F_->createSave("ret", SBN);

optimizedF_ = optimizeFunctionForTest(F_);

auto *optSave =
llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save->getName()));
ASSERT_TRUE(optSave);
auto *optRN = llvm::dyn_cast<ReshapeNode>(optSave->getInput());
ASSERT_TRUE(optRN);
EXPECT_EQ(optRN->getResult().dims(), llvm::makeArrayRef(dimsOut));
auto *optSBN = llvm::dyn_cast<SubNode>(optRN->getInput());
ASSERT_TRUE(optSBN);
EXPECT_EQ(optSBN->getResult().dims(), llvm::makeArrayRef(dimsIn));
auto *optMXN = llvm::dyn_cast<MaxNode>(optSBN->getLHS());
ASSERT_TRUE(optMXN);
EXPECT_EQ(optMXN->getResult().dims(), llvm::makeArrayRef(dimsIn));
auto *optMLN = llvm::dyn_cast<MulNode>(optMXN->getRHS());
ASSERT_TRUE(optMLN);
EXPECT_EQ(optMLN->getResult().dims(), llvm::makeArrayRef(dimsIn));
auto *optAN = llvm::dyn_cast<AddNode>(optMLN->getLHS());
ASSERT_TRUE(optAN);
EXPECT_EQ(optAN->getResult().dims(), llvm::makeArrayRef(dimsIn));

bindings_.allocate(in1)->getHandle<int8_t>().randomize(-128, 127,
mod_.getPRNG());
bindings_.allocate(in2)->getHandle<int8_t>().randomize(-128, 127,
mod_.getPRNG());
checkNumericalEquivalence(0.f);
}

TEST_F(GraphOptz, OptConvertToDequantize) {
auto *I =
mod_.createPlaceholder(ElemKind::Int8QTy, {32, 64}, 0.2f, 1, "A", false);
4 changes: 1 addition & 3 deletions tools/ClassGen/NodeGen.cpp
@@ -673,7 +673,6 @@ int main(int argc, char **argv) {
.addResultFromCtorArg()
.dataParallel()
.setDocstring("Performs element-wise exponential to the Input.");
// clang-format on

BB.newNode("Logit")
.addInput("Input")
@@ -685,8 +684,8 @@
BB.newNode("NonZero")
.addInput("Cond")
.addResultFromCtorArg()
.dataParallel()
.setDocstring("Selects indices of the true elements in Cond");
// clang-format on
Contributor

nit: Assuming you're doing this because NonZero is seeing weird formatting when dataParallel is removed -- can we keep the original // clang-format on after Exp and add // clang-format off just above NonZero here? Otherwise we're wrapping Logit too unnecessarily.
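
Roughly what the nit is asking for (a sketch assembled from the hunks above, not the actual file contents): restore the on-comment right after Exp and give NonZero its own off/on pair, leaving Logit normally formatted in between.

  BB.newNode("Exp")
      .addInput("Input")
      .addResultFromCtorArg()
      .dataParallel()
      .setDocstring("Performs element-wise exponential to the Input.");
  // clang-format on

  // ... Logit stays normally formatted ...

  // clang-format off
  BB.newNode("NonZero")
      .addInput("Cond")
      .addResultFromCtorArg()
      .setDocstring("Selects indices of the true elements in Cond");
  // clang-format on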


BB.newNode("Select")
.addInput("Cond")
@@ -1432,7 +1431,6 @@ int main(int argc, char **argv) {
.addInput("Input")
.addInput("Mapping")
.addResultFromCtorArg()
.dataParallel()
Contributor

Why is this necessary?

Contributor Author

IntLookupTable is not really a data parallel operation. We need the whole Mapping input in order to produce a single output element.
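
A rough sketch of why (interpreter-style pseudocode, an approximation rather than Glow's exact kernel): the index used to read Mapping depends on the runtime value of each input element, so Mapping cannot be partitioned the way a data-parallel operand can.

// Approximate Int8 semantics: every output element may read any entry of
// Mapping, chosen by the input value at that position.
for (size_t i = 0, e = in.size(); i < e; ++i) {
  out.raw(i) = mapping.raw(static_cast<size_t>(
      static_cast<int32_t>(in.raw(i)) - std::numeric_limits<int8_t>::min()));
}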

.setDocstring("Simple mapping between quantized numbers."
"This can be used as quantized sigmoid or tanh functions.");
