diff --git a/lib/Graph/TensorLayout.cpp b/lib/Graph/TensorLayout.cpp
index 3da145fd78..7d29a74783 100644
--- a/lib/Graph/TensorLayout.cpp
+++ b/lib/Graph/TensorLayout.cpp
@@ -699,6 +699,8 @@ static bool acceptsAnyInputLayout(const glow::Node *node) {
   case Kinded::Kind::BatchNormalizationNodeKind:
   case Kinded::Kind::BatchNormalizationGradNodeKind:
   case Kinded::Kind::PadNodeKind:
+  case Kinded::Kind::NonZeroNodeKind:
+  case Kinded::Kind::IntLookupTableNodeKind:
   case Kinded::Kind::ReshapeNodeKind:
   case Kinded::Kind::MeanVarNormalizationNodeKind:
   case Kinded::Kind::MatMulNodeKind:
diff --git a/lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp b/lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp
index 8947208f32..acc4ef8a0b 100644
--- a/lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp
+++ b/lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp
@@ -1143,37 +1143,122 @@ bool SinkReshapes::run(Function *F, const CompilationContext &cctx) {
     auto *node = &N;

     // Sink Reshape below eltwise nodes.
-    if (!node->isDataParallel() || node->hasSideEffects()) {
-      continue;
-    }
+    if (node->isDataParallel() && !node->hasSideEffects()) {
+      // Unary eltwise nodes.
+      if (node->getNumInputs() == 1 && node->getNumResults() == 1) {
+        DCHECK(node->getNthResult(0).getType()->isFusedQuantizedType() ||
+               node->getNthInput(0).getType()->isFusedQuantizedType() ||
+               node->getNthResult(0).dims().equals(node->getNthInput(0).dims()))
+            << "SinkReshapes: not an element-wise node: " << node->toString();
+
+        auto *RS = dyn_cast<ReshapeNode>(node->getNthInput(0));
+        if (!RS) {
+          continue;
+        }

-    // Unary eltwise nodes.
-    if (node->getNumInputs() != 1 || node->getNumResults() != 1) {
-      continue;
-    }
+        // Create new eltwise node.
+        auto in = RS->getInput();
+        auto out = node->getNthResult(0);
+        auto newTy =
+            F->getParent()->uniqueTypeWithNewShape(out.getType(), in.dims());
+        auto *newN = F->addNode(node->clone());
+        newN->setNthInput(0, in);
+        newN->setTypeUnsafe(0, newTy);
+        newN->setPredicate(node->getPredicate());
+
+        // Create new Reshape.
+        auto *newRS = F->createReshape(RS->getName(), newN,
+                                       RS->getResult().getType()->dims());
+        newRS->setPredicate(node->getPredicate());
+        out.replaceAllUsesOfWith(newRS->getResult());

-    auto *RS = dyn_cast<ReshapeNode>(node->getNthInput(0));
-    if (!RS) {
-      continue;
-    }
+        changed = true;
+      }
+
+      // Binary eltwise nodes.
+      if (node->getNumInputs() == 2 && node->getNumResults() == 1) {
+        DCHECK(node->getNthResult(0).getType()->isFusedQuantizedType() ||
+               node->getNthInput(0).getType()->isFusedQuantizedType() ||
+               node->getNthResult(0).dims().equals(node->getNthInput(0).dims()))
+            << "SinkReshapes: not an element-wise node: " << node->toString();
+        DCHECK(node->getNthResult(0).getType()->isFusedQuantizedType() ||
+               node->getNthInput(1).getType()->isFusedQuantizedType() ||
+               node->getNthResult(0).dims().equals(node->getNthInput(1).dims()))
+            << "SinkReshapes: not an element-wise node: " << node->toString();
+
+        // At least one of the inputs must be a Reshape.
+        // If both inputs are Reshapes, they must have the same dimensions.
+        auto *LRN = dyn_cast<ReshapeNode>(node->getNthInput(0));
+        auto *RRN = dyn_cast<ReshapeNode>(node->getNthInput(1));
+        if (!LRN && !RRN) {
+          continue;
+        }
+        if (LRN && RRN &&
+            !LRN->getResult().dims().equals(RRN->getResult().dims())) {
+          continue;
+        }

-    // Create new eltwise node.
-    auto in = RS->getInput();
-    auto out = node->getNthResult(0);
-    auto newTy =
-        F->getParent()->uniqueTypeWithNewShape(out.getType(), in.dims());
-    auto *newN = F->addNode(node->clone());
-    newN->setNthInput(0, in);
-    newN->setTypeUnsafe(0, newTy);
-    newN->setPredicate(node->getPredicate());
+        // Canonicalize node to simplify transformation implementation (make LHS
+        // always be the input with a Reshape).
+        bool swap = (LRN == nullptr);
+        auto nv = node->getNthInput(1);
+        if (swap) {
+          nv = node->getNthInput(0);
+          LRN = RRN;
+          RRN = nullptr;
+        }

-    // Create new Reshape.
-    auto *newRS = F->createReshape(RS->getName(), newN,
-                                   RS->getResult().getType()->dims());
-    newRS->setPredicate(node->getPredicate());
-    out.replaceAllUsesOfWith(newRS->getResult());
+        // RHS must be either a Reshape or a Constant (+ Quantize) or Splat.
+        auto *RQ = dyn_cast<QuantizeNode>(nv);
+        auto *RC = dyn_cast<Constant>(RQ ? RQ->getInput() : nv);
+        auto *RS = dyn_cast<SplatNode>(nv);
+        if (!RRN && !RC && !RS) {
+          continue;
+        }

-    changed = true;
+        // Create new Constant, Quantize or Splat, if needed.
+        NodeValue rhs;
+        if (RRN) {
+          rhs = RRN->getInput();
+        }
+        if (RC) {
+          auto ty = F->getParent()->uniqueTypeWithNewShape(
+              RC->getType(), LRN->getInput().dims());
+          auto *newC = F->getParent()->createConstant(ty, RC->getName());
+          newC->getPayloadMutable().copyRawFrom(&RC->getPayload());
+          rhs = newC->getOutput();
+        }
+        if (RQ) {
+          auto ty = F->getParent()->uniqueTypeWithNewShape(
+              RQ->getResult().getType(), LRN->getInput().dims());
+          rhs = F->createQuantize(RQ->getName(), rhs, ty);
+        }
+        if (RS) {
+          auto ty = F->getParent()->uniqueTypeWithNewShape(
+              RS->getResult().getType(), LRN->getInput().dims());
+          rhs = F->createSplat(RS->getName(), ty, RS->getValue());
+        }
+
+        // Create new eltwise node.
+        auto lhs = LRN->getInput();
+        auto out = node->getNthResult(0);
+        auto newTy =
+            F->getParent()->uniqueTypeWithNewShape(out.getType(), lhs.dims());
+        auto *newN = F->addNode(node->clone());
+        newN->setNthInput(0, swap ? rhs : lhs);
+        newN->setNthInput(1, swap ? lhs : rhs);
+        newN->setTypeUnsafe(0, newTy);
+        newN->setPredicate(node->getPredicate());
+
+        // Create new Reshape.
+        auto *newRN = F->createReshape(LRN->getName(), newN,
+                                       LRN->getResult().getType()->dims());
+        newRN->setPredicate(node->getPredicate());
+        out.replaceAllUsesOfWith(newRN->getResult());
+
+        changed = true;
+      }
+    }
   }
   return changed;
 }
diff --git a/tests/unittests/GraphOptzTest.cpp b/tests/unittests/GraphOptzTest.cpp
index 35d5614120..d1edd7fea8 100644
--- a/tests/unittests/GraphOptzTest.cpp
+++ b/tests/unittests/GraphOptzTest.cpp
@@ -2323,21 +2323,23 @@ TEST_F(GraphOptz, ReshapeAfterSplat) {
   const dim_t reshape[] = {1, 6000};
   Type t1(ElemKind::FloatTy, shape);
   Type t2(ElemKind::FloatTy, reshape);
-  Node *input = F_->getParent()->createPlaceholder(ElemKind::FloatTy, shape,
-                                                   "input", true);
+  Node *input1 = F_->getParent()->createPlaceholder(ElemKind::FloatTy, shape,
+                                                    "input1", true);
+  Node *input2 = F_->getParent()->createPlaceholder(ElemKind::FloatTy, reshape,
+                                                    "input2", true);
   auto *Z1 = F_->createSplat("zero1", &t1, 1.5);
-  auto *A1 = F_->createAdd("add1", Z1->getResult().getType(), input, Z1);
+  auto *A1 = F_->createAdd("add1", Z1->getResult().getType(), input1, Z1);
   auto *R1 = F_->createReshape("reshape1", Z1, reshape);
   // Z1 is used by R1 and A1.
   // The reshape optimization will thus NOT be able to remove this reshape node
   // (R1).
-  auto *R2 = F_->createReshape("reshape2", A1, reshape);
-  auto *A2 = F_->createAdd("add", R1->getResult().getType(), R1, R2);
+  F_->createSave("save", A1);
+  auto *A2 = F_->createAdd("add", R1->getResult().getType(), R1, input2);
   auto *Z2 = F_->createSplat("zero2", &t1, 2.5);
-  auto *R3 = F_->createReshape("reshape3", Z2, reshape);
-  // Z2 is only used by R3.
-  // The Z2,R3 nodes will be replaced by a new splat node with the shape of R3.
-  auto *A3 = F_->createAdd("add", A2->getResult().getType(), A2, R3);
+  auto *R2 = F_->createReshape("reshape3", Z2, reshape);
+  // Z2 is only used by R2.
+  // The Z2,R2 nodes will be replaced by a new splat node with the shape of R2.
+  auto *A3 = F_->createAdd("add", A2->getResult().getType(), A2, R2);
   auto *O = F_->createSave("ret", A3);

   // Before optimization, we have 9 nodes in the graph.
@@ -2352,7 +2354,7 @@ TEST_F(GraphOptz, ReshapeAfterSplat) {
   // replace by a new splat node.
   EXPECT_EQ(F_->getNodes().size(), 8);

-  // The second input of A3 shoule be a splat node with a shape of R3.
+  // The second input of A3 should be a splat node with a shape of R2.
   auto *newA3 = llvm::dyn_cast<AddNode>(O->getInput());
   ASSERT_TRUE(newA3);
   auto *SN = llvm::dyn_cast<SplatNode>(newA3->getRHS());
@@ -2362,8 +2364,8 @@ TEST_F(GraphOptz, ReshapeAfterSplat) {
   // R1 should still be in the graph.
   EXPECT_TRUE(functionContainsNode(F_, R1));

-  // R3 and Z2 should not be in the graph any more.
-  EXPECT_FALSE(functionContainsNode(F_, R3));
+  // R2 and Z2 should not be in the graph any more.
+  EXPECT_FALSE(functionContainsNode(F_, R2));
   EXPECT_FALSE(functionContainsNode(F_, Z2));
 }

@@ -8043,6 +8045,62 @@ TEST_F(GraphOptz, SinkReshapeBelowUnaryEltwiseOps) {
   checkNumericalEquivalence(0.f);
 }

+TEST_F(GraphOptz, SinkReshapeBelowBinaryEltwiseOps) {
+  const dim_t dimsIn[] = {10, 10};
+  const dim_t dimsOut[] = {5, 5, 4};
+
+  // Prepare inputs.
+  auto *in1 = mod_.createPlaceholder(glow::ElemKind::Int8QTy, dimsIn, 0.12f, 0,
+                                     "in1", false);
+  auto *in2 = mod_.createPlaceholder(glow::ElemKind::Int8QTy, dimsIn, 0.17f, 0,
+                                     "in2", false);
+  auto *QCN =
+      mod_.createConstant(ElemKind::Int8QTy, dimsOut, 0.13f, 0, "quant_const");
+  auto *FCN = mod_.createConstant(ElemKind::FloatTy, dimsOut, "float_const");
+  auto qTy = mod_.uniqueType(ElemKind::Int8QTy, dimsOut, 0.15f, 0);
+  auto *QN = F_->createQuantize("quantize", FCN, qTy);
+  auto *SN = F_->createSplat("splat", qTy, 1.79f);
+  QCN->getHandle<int8_t>().randomize(-128, 127, mod_.getPRNG());
+  FCN->getHandle().randomize(-1.f, 2.f, mod_.getPRNG());
+
+  // Test different combinations of Reshape, Constant, Quantize, Splat passed as
+  // LHS or RHS.
+  auto *RN1 = F_->createReshape("reshape", in1, dimsOut);
+  auto *RN2 = F_->createReshape("reshape", in2, dimsOut);
+  auto *AN = F_->createAdd("add", RN1, RN2);
+  auto *MLN = F_->createMul("mul", AN, QCN);
+  auto *MXN = F_->createMax("max", QN, MLN);
+  auto *SBN = F_->createSub("sub", MXN, SN);
+  auto *save = F_->createSave("ret", SBN);
+
+  optimizedF_ = optimizeFunctionForTest(F_);
+
+  auto *optSave =
+      llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save->getName()));
+  ASSERT_TRUE(optSave);
+  auto *optRN = llvm::dyn_cast<ReshapeNode>(optSave->getInput());
+  ASSERT_TRUE(optRN);
+  EXPECT_EQ(optRN->getResult().dims(), llvm::makeArrayRef(dimsOut));
+  auto *optSBN = llvm::dyn_cast<SubNode>(optRN->getInput());
+  ASSERT_TRUE(optSBN);
+  EXPECT_EQ(optSBN->getResult().dims(), llvm::makeArrayRef(dimsIn));
+  auto *optMXN = llvm::dyn_cast<MaxNode>(optSBN->getLHS());
+  ASSERT_TRUE(optMXN);
+  EXPECT_EQ(optMXN->getResult().dims(), llvm::makeArrayRef(dimsIn));
+  auto *optMLN = llvm::dyn_cast<MulNode>(optMXN->getRHS());
+  ASSERT_TRUE(optMLN);
+  EXPECT_EQ(optMLN->getResult().dims(), llvm::makeArrayRef(dimsIn));
+  auto *optAN = llvm::dyn_cast<AddNode>(optMLN->getLHS());
+  ASSERT_TRUE(optAN);
+  EXPECT_EQ(optAN->getResult().dims(), llvm::makeArrayRef(dimsIn));
+
+  bindings_.allocate(in1)->getHandle<int8_t>().randomize(-128, 127,
+                                                         mod_.getPRNG());
+  bindings_.allocate(in2)->getHandle<int8_t>().randomize(-128, 127,
+                                                         mod_.getPRNG());
+  checkNumericalEquivalence(0.f);
+}
+
 TEST_F(GraphOptz, OptConvertToDequantize) {
   auto *I = mod_.createPlaceholder(ElemKind::Int8QTy, {32, 64}, 0.2f, 1, "A",
                                    false);
diff --git a/tools/ClassGen/NodeGen.cpp b/tools/ClassGen/NodeGen.cpp
index def7b09b52..2bb0faa958 100644
--- a/tools/ClassGen/NodeGen.cpp
+++ b/tools/ClassGen/NodeGen.cpp
@@ -673,7 +673,6 @@ int main(int argc, char **argv) {
       .addResultFromCtorArg()
       .dataParallel()
       .setDocstring("Performs element-wise exponential to the Input.");
-  // clang-format on

   BB.newNode("Logit")
       .addInput("Input")
@@ -685,8 +684,8 @@ int main(int argc, char **argv) {
   BB.newNode("NonZero")
       .addInput("Cond")
       .addResultFromCtorArg()
-      .dataParallel()
       .setDocstring("Selects indices of the true elements in Cond");
+  // clang-format on

   BB.newNode("Select")
       .addInput("Cond")
@@ -1432,7 +1431,6 @@ int main(int argc, char **argv) {
       .addInput("Input")
       .addInput("Mapping")
       .addResultFromCtorArg()
-      .dataParallel()
       .setDocstring("Simple mapping between quantized numbers."
                     "This can be used as quantized sigmoid or tanh functions.");