
Commit bbc3eab

[GraphOptimizer] Extend Reshape sinking pass for binary eltwise ops.

1 parent: dcd7679

4 files changed: +184 −41 lines

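Note: SinkReshapes rewrites Eltwise(Reshape(x)) into Reshape(Eltwise(x)) so that Reshape nodes migrate toward the graph outputs, where later passes can merge or eliminate them; this commit extends the pass from unary to binary element-wise ops. A minimal sketch of the before/after graphs, using Glow builder calls that appear in the diffs below (x, y, and the shapes are illustrative, mirroring the new unit test):

  // Before: two {10, 10} tensors are reshaped, then added.
  auto *r1 = F->createReshape("r1", x, {5, 5, 4});
  auto *r2 = F->createReshape("r2", y, {5, 5, 4});
  auto *add = F->createAdd("add", r1, r2);      // result: {5, 5, 4}

  // After sinking: the Add runs at the original shape and a single
  // Reshape is applied to its result.
  auto *sunkAdd = F->createAdd("add", x, y);    // result: {10, 10}
  auto *out = F->createReshape("r1", sunkAdd, {5, 5, 4});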

lib/Graph/TensorLayout.cpp

Lines changed: 2 additions & 0 deletions
@@ -699,6 +699,8 @@ static bool acceptsAnyInputLayout(const glow::Node *node) {
   case Kinded::Kind::BatchNormalizationNodeKind:
   case Kinded::Kind::BatchNormalizationGradNodeKind:
   case Kinded::Kind::PadNodeKind:
+  case Kinded::Kind::NonZeroNodeKind:
+  case Kinded::Kind::IntLookupTableNodeKind:
   case Kinded::Kind::ReshapeNodeKind:
   case Kinded::Kind::MeanVarNormalizationNodeKind:
   case Kinded::Kind::MatMulNodeKind:

lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp

Lines changed: 111 additions & 26 deletions
@@ -1143,37 +1143,122 @@ bool SinkReshapes::run(Function *F, const CompilationContext &cctx) {
     auto *node = &N;

     // Sink Reshape below eltwise nodes.
-    if (!node->isDataParallel() || node->hasSideEffects()) {
-      continue;
-    }
+    if (node->isDataParallel() && !node->hasSideEffects()) {
+      // Unary eltwise nodes.
+      if (node->getNumInputs() == 1 && node->getNumResults() == 1) {
+        DCHECK(node->getNthResult(0).getType()->isFusedQuantizedType() ||
+               node->getNthInput(0).getType()->isFusedQuantizedType() ||
+               node->getNthResult(0).dims().equals(node->getNthInput(0).dims()))
+            << "SinkReshapes: not an element-wise node: " << node->toString();
+
+        auto *RS = dyn_cast<ReshapeNode>(node->getNthInput(0));
+        if (!RS) {
+          continue;
+        }

-    // Unary eltwise nodes.
-    if (node->getNumInputs() != 1 || node->getNumResults() != 1) {
-      continue;
-    }
+        // Create new eltwise node.
+        auto in = RS->getInput();
+        auto out = node->getNthResult(0);
+        auto newTy =
+            F->getParent()->uniqueTypeWithNewShape(out.getType(), in.dims());
+        auto *newN = F->addNode(node->clone());
+        newN->setNthInput(0, in);
+        newN->setTypeUnsafe(0, newTy);
+        newN->setPredicate(node->getPredicate());
+
+        // Create new Reshape.
+        auto *newRS = F->createReshape(RS->getName(), newN,
+                                       RS->getResult().getType()->dims());
+        newRS->setPredicate(node->getPredicate());
+        out.replaceAllUsesOfWith(newRS->getResult());

-    auto *RS = dyn_cast<ReshapeNode>(node->getNthInput(0));
-    if (!RS) {
-      continue;
-    }
+        changed = true;
+      }
+
+      // Binary eltwise nodes.
+      if (node->getNumInputs() == 2 && node->getNumResults() == 1) {
+        DCHECK(node->getNthResult(0).getType()->isFusedQuantizedType() ||
+               node->getNthInput(0).getType()->isFusedQuantizedType() ||
+               node->getNthResult(0).dims().equals(node->getNthInput(0).dims()))
+            << "SinkReshapes: not an element-wise node: " << node->toString();
+        DCHECK(node->getNthResult(0).getType()->isFusedQuantizedType() ||
+               node->getNthInput(1).getType()->isFusedQuantizedType() ||
+               node->getNthResult(0).dims().equals(node->getNthInput(1).dims()))
+            << "SinkReshapes: not an element-wise node: " << node->toString();
+
+        // At least one of the inputs must be a Reshape.
+        // If both inputs are Reshapes, they must have the same dimensions.
+        auto *LRN = dyn_cast<ReshapeNode>(node->getNthInput(0));
+        auto *RRN = dyn_cast<ReshapeNode>(node->getNthInput(1));
+        if (!LRN && !RRN) {
+          continue;
+        }
+        if (LRN && RRN &&
+            !LRN->getResult().dims().equals(RRN->getResult().dims())) {
+          continue;
+        }

-    // Create new eltwise node.
-    auto in = RS->getInput();
-    auto out = node->getNthResult(0);
-    auto newTy =
-        F->getParent()->uniqueTypeWithNewShape(out.getType(), in.dims());
-    auto *newN = F->addNode(node->clone());
-    newN->setNthInput(0, in);
-    newN->setTypeUnsafe(0, newTy);
-    newN->setPredicate(node->getPredicate());
+        // Canonicalize node to simplify transformation implementation (make LHS
+        // always be the input with a Reshape).
+        bool swap = (LRN == nullptr);
+        auto nv = node->getNthInput(1);
+        if (swap) {
+          nv = node->getNthInput(0);
+          LRN = RRN;
+          RRN = nullptr;
+        }

-    // Create new Reshape.
-    auto *newRS = F->createReshape(RS->getName(), newN,
-                                   RS->getResult().getType()->dims());
-    newRS->setPredicate(node->getPredicate());
-    out.replaceAllUsesOfWith(newRS->getResult());
+        // RHS must be either a Reshape or a Constant (+ Quantize) or Splat.
+        auto *RQ = dyn_cast<QuantizeNode>(nv);
+        auto *RC = dyn_cast<Constant>(RQ ? RQ->getInput() : nv);
+        auto *RS = dyn_cast<SplatNode>(nv);
+        if (!RRN && !RC && !RS) {
+          continue;
+        }

-    changed = true;
+        // Create new Constant, Quantize or Splat, if needed.
+        NodeValue rhs;
+        if (RRN) {
+          rhs = RRN->getInput();
+        }
+        if (RC) {
+          auto ty = F->getParent()->uniqueTypeWithNewShape(
+              RC->getType(), LRN->getInput().dims());
+          auto *newC = F->getParent()->createConstant(ty, RC->getName());
+          newC->getPayloadMutable().copyRawFrom(&RC->getPayload());
+          rhs = newC->getOutput();
+        }
+        if (RQ) {
+          auto ty = F->getParent()->uniqueTypeWithNewShape(
+              RQ->getResult().getType(), LRN->getInput().dims());
+          rhs = F->createQuantize(RQ->getName(), rhs, ty);
+        }
+        if (RS) {
+          auto ty = F->getParent()->uniqueTypeWithNewShape(
+              RS->getResult().getType(), LRN->getInput().dims());
+          rhs = F->createSplat(RS->getName(), ty, RS->getValue());
+        }
+
+        // Create new eltwise node.
+        auto lhs = LRN->getInput();
+        auto out = node->getNthResult(0);
+        auto newTy =
+            F->getParent()->uniqueTypeWithNewShape(out.getType(), lhs.dims());
+        auto *newN = F->addNode(node->clone());
+        newN->setNthInput(0, swap ? rhs : lhs);
+        newN->setNthInput(1, swap ? lhs : rhs);
+        newN->setTypeUnsafe(0, newTy);
+        newN->setPredicate(node->getPredicate());
+
+        // Create new Reshape.
+        auto *newRN = F->createReshape(LRN->getName(), newN,
+                                       LRN->getResult().getType()->dims());
+        newRN->setPredicate(node->getPredicate());
+        out.replaceAllUsesOfWith(newRN->getResult());
+
+        changed = true;
+      }
+    }
   }
   return changed;
 }
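Note: when only one operand comes from a Reshape, the other must be a Constant (possibly behind a Quantize) or a Splat, and the pass re-materializes it at the pre-Reshape shape instead of inserting another Reshape. Since Reshape does not reorder row-major data, a Constant's raw payload is copied unchanged into a Constant of the new type, while a Quantize or Splat is simply re-created with the new type. A sketch of the net effect on one such pattern (names illustrative):

  // Before: x is {10, 10}; C is a {5, 5, 4} Constant.
  //   out = Mul(Reshape(x, {5, 5, 4}), C)
  // After: C's payload is copied into a {10, 10} Constant newC, and the
  // Reshape is sunk below the Mul:
  //   out = Reshape(Mul(x, newC), {5, 5, 4})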

tests/unittests/GraphOptzTest.cpp

Lines changed: 70 additions & 12 deletions
@@ -2323,21 +2323,23 @@ TEST_F(GraphOptz, ReshapeAfterSplat) {
   const dim_t reshape[] = {1, 6000};
   Type t1(ElemKind::FloatTy, shape);
   Type t2(ElemKind::FloatTy, reshape);
-  Node *input = F_->getParent()->createPlaceholder(ElemKind::FloatTy, shape,
-                                                   "input", true);
+  Node *input1 = F_->getParent()->createPlaceholder(ElemKind::FloatTy, shape,
+                                                    "input1", true);
+  Node *input2 = F_->getParent()->createPlaceholder(ElemKind::FloatTy, reshape,
+                                                    "input2", true);
   auto *Z1 = F_->createSplat("zero1", &t1, 1.5);
-  auto *A1 = F_->createAdd("add1", Z1->getResult().getType(), input, Z1);
+  auto *A1 = F_->createAdd("add1", Z1->getResult().getType(), input1, Z1);
   auto *R1 = F_->createReshape("reshape1", Z1, reshape);
   // Z1 is used by R1 and A1.
   // The reshape optimization will thus NOT be able to remove this reshape node
   // (R1).
-  auto *R2 = F_->createReshape("reshape2", A1, reshape);
-  auto *A2 = F_->createAdd("add", R1->getResult().getType(), R1, R2);
+  F_->createSave("save", A1);
+  auto *A2 = F_->createAdd("add", R1->getResult().getType(), R1, input2);
   auto *Z2 = F_->createSplat("zero2", &t1, 2.5);
-  auto *R3 = F_->createReshape("reshape3", Z2, reshape);
-  // Z2 is only used by R3.
-  // The Z2,R3 nodes will be replaced by a new splat node with the shape of R3.
-  auto *A3 = F_->createAdd("add", A2->getResult().getType(), A2, R3);
+  auto *R2 = F_->createReshape("reshape3", Z2, reshape);
+  // Z2 is only used by R2.
+  // The Z2,R2 nodes will be replaced by a new splat node with the shape of R2.
+  auto *A3 = F_->createAdd("add", A2->getResult().getType(), A2, R2);
   auto *O = F_->createSave("ret", A3);

   // Before optimization, we have 9 nodes in the graph.

@@ -2352,7 +2354,7 @@ TEST_F(GraphOptz, ReshapeAfterSplat) {
   // replace by a new splat node.
   EXPECT_EQ(F_->getNodes().size(), 8);

-  // The second input of A3 shoule be a splat node with a shape of R3.
+  // The second input of A3 shoule be a splat node with a shape of R2.
   auto *newA3 = llvm::dyn_cast<AddNode>(O->getInput());
   ASSERT_TRUE(newA3);
   auto *SN = llvm::dyn_cast<SplatNode>(newA3->getRHS());

@@ -2362,8 +2364,8 @@ TEST_F(GraphOptz, ReshapeAfterSplat) {
   // R1 should still be in the graph.
   EXPECT_TRUE(functionContainsNode(F_, R1));

-  // R3 and Z2 should not be in the graph any more.
-  EXPECT_FALSE(functionContainsNode(F_, R3));
+  // R2 and Z2 should not be in the graph any more.
+  EXPECT_FALSE(functionContainsNode(F_, R2));
   EXPECT_FALSE(functionContainsNode(F_, Z2));
 }

@@ -8043,6 +8045,62 @@ TEST_F(GraphOptz, SinkReshapeBelowUnaryEltwiseOps) {
   checkNumericalEquivalence(0.f);
 }

+TEST_F(GraphOptz, SinkReshapeBelowBinaryEltwiseOps) {
+  const dim_t dimsIn[] = {10, 10};
+  const dim_t dimsOut[] = {5, 5, 4};
+
+  // Prepare inputs.
+  auto *in1 = mod_.createPlaceholder(glow::ElemKind::Int8QTy, dimsIn, 0.12f, 0,
+                                     "in1", false);
+  auto *in2 = mod_.createPlaceholder(glow::ElemKind::Int8QTy, dimsIn, 0.17f, 0,
+                                     "in2", false);
+  auto *QCN =
+      mod_.createConstant(ElemKind::Int8QTy, dimsOut, 0.13f, 0, "quant_const");
+  auto *FCN = mod_.createConstant(ElemKind::FloatTy, dimsOut, "float_const");
+  auto qTy = mod_.uniqueType(ElemKind::Int8QTy, dimsOut, 0.15f, 0);
+  auto *QN = F_->createQuantize("quantize", FCN, qTy);
+  auto *SN = F_->createSplat("splat", qTy, 1.79f);
+  QCN->getHandle<int8_t>().randomize(-128, 127, mod_.getPRNG());
+  FCN->getHandle<float>().randomize(-1.f, 2.f, mod_.getPRNG());
+
+  // Test different combinations of Reshape, Constant, Quantize, Splat passed as
+  // LHS or RHS.
+  auto *RN1 = F_->createReshape("reshape", in1, dimsOut);
+  auto *RN2 = F_->createReshape("reshape", in2, dimsOut);
+  auto *AN = F_->createAdd("add", RN1, RN2);
+  auto *MLN = F_->createMul("mul", AN, QCN);
+  auto *MXN = F_->createMax("max", QN, MLN);
+  auto *SBN = F_->createSub("sub", MXN, SN);
+  auto *save = F_->createSave("ret", SBN);
+
+  optimizedF_ = optimizeFunctionForTest(F_);
+
+  auto *optSave =
+      llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save->getName()));
+  ASSERT_TRUE(optSave);
+  auto *optRN = llvm::dyn_cast<ReshapeNode>(optSave->getInput());
+  ASSERT_TRUE(optRN);
+  EXPECT_EQ(optRN->getResult().dims(), llvm::makeArrayRef(dimsOut));
+  auto *optSBN = llvm::dyn_cast<SubNode>(optRN->getInput());
+  ASSERT_TRUE(optSBN);
+  EXPECT_EQ(optSBN->getResult().dims(), llvm::makeArrayRef(dimsIn));
+  auto *optMXN = llvm::dyn_cast<MaxNode>(optSBN->getLHS());
+  ASSERT_TRUE(optMXN);
+  EXPECT_EQ(optMXN->getResult().dims(), llvm::makeArrayRef(dimsIn));
+  auto *optMLN = llvm::dyn_cast<MulNode>(optMXN->getRHS());
+  ASSERT_TRUE(optMLN);
+  EXPECT_EQ(optMLN->getResult().dims(), llvm::makeArrayRef(dimsIn));
+  auto *optAN = llvm::dyn_cast<AddNode>(optMLN->getLHS());
+  ASSERT_TRUE(optAN);
+  EXPECT_EQ(optAN->getResult().dims(), llvm::makeArrayRef(dimsIn));
+
+  bindings_.allocate(in1)->getHandle<int8_t>().randomize(-128, 127,
+                                                         mod_.getPRNG());
+  bindings_.allocate(in2)->getHandle<int8_t>().randomize(-128, 127,
+                                                         mod_.getPRNG());
+  checkNumericalEquivalence(0.f);
+}
+
 TEST_F(GraphOptz, OptConvertToDequantize) {
   auto *I =
       mod_.createPlaceholder(ElemKind::Int8QTy, {32, 64}, 0.2f, 1, "A", false);

tools/ClassGen/NodeGen.cpp

Lines changed: 1 addition & 3 deletions
@@ -673,7 +673,6 @@ int main(int argc, char **argv) {
       .addResultFromCtorArg()
       .dataParallel()
       .setDocstring("Performs element-wise exponential to the Input.");
-  // clang-format on

   BB.newNode("Logit")
       .addInput("Input")

@@ -685,8 +684,8 @@ int main(int argc, char **argv) {
   BB.newNode("NonZero")
       .addInput("Cond")
       .addResultFromCtorArg()
-      .dataParallel()
       .setDocstring("Selects indices of the true elements in Cond");
+  // clang-format on

   BB.newNode("Select")
       .addInput("Cond")

@@ -1432,7 +1431,6 @@ int main(int argc, char **argv) {
       .addInput("Input")
       .addInput("Mapping")
      .addResultFromCtorArg()
-      .dataParallel()
       .setDocstring("Simple mapping between quantized numbers."
                     "This can be used as quantized sigmoid or tanh functions.");
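Note: dropping .dataParallel() here from NonZero and IntLookupTable keeps both out of the extended SinkReshapes pass above, whose new DCHECKs require every input of a data-parallel node to match the result dims: NonZero's result (the indices of the true elements) has a different shape than its input, and IntLookupTable's Mapping input is a lookup table rather than an element-wise operand. The TensorLayout.cpp change presumably compensates for the lost flag by listing both kinds in acceptsAnyInputLayout explicitly.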
