From 1915afa929cda249b87f3a637dcf49f036d24e6c Mon Sep 17 00:00:00 2001 From: Jordan Fix Date: Mon, 18 Jan 2021 11:59:16 -0800 Subject: [PATCH] Add improved support for parallelization and related graph opts Summary: - Add RescaleQuantized parallelization support to graph opts' parallelization code - On NNPI, mirror Rescale parallelization for FC/Relus that come before it - Sink Reshapes below Quantize and ConvertTo - Remove unnecessary ConvertTo when following a Dequantize (i.e. just change the elem kind of the Dequantize instead) Differential Revision: D25947824 fbshipit-source-id: 897e0aa507293647fdf5ff58d0119427dcee5aee --- include/glow/Graph/NodeValue.h | 2 + lib/Backends/NNPI/NNPI.cpp | 18 +++ lib/Graph/NodeValue.cpp | 4 + .../GraphOptimizer/GraphOptimizer.cpp | 105 ++++++++++++++---- tests/unittests/GraphOptzTest.cpp | 79 +++++++++++++ tests/unittests/NNPIOptPipelineTest.cpp | 45 +++++++- 6 files changed, 230 insertions(+), 23 deletions(-) diff --git a/include/glow/Graph/NodeValue.h b/include/glow/Graph/NodeValue.h index 761b7ca923..fd32efb419 100755 --- a/include/glow/Graph/NodeValue.h +++ b/include/glow/Graph/NodeValue.h @@ -96,6 +96,8 @@ struct NodeValue { /// @{ ElemKind getElementType() const; llvm::ArrayRef dims() const; + float getScale() const; + int32_t getOffset() const; /// @} bool operator==(const NodeValue &O) const { diff --git a/lib/Backends/NNPI/NNPI.cpp b/lib/Backends/NNPI/NNPI.cpp index ca7145bb4e..ee23114fa0 100644 --- a/lib/Backends/NNPI/NNPI.cpp +++ b/lib/Backends/NNPI/NNPI.cpp @@ -748,6 +748,24 @@ static void setupBasicParallelizationConfigs( } } + if (auto *R = llvm::dyn_cast(node)) { + // For Rescales that are preceded by FC or Relu, mirror their + // parallelization. 
+ Node *inputNode = R->getInput().getNode(); + if (!llvm::isa(inputNode) && + !llvm::isa(inputNode)) { + continue; + } + auto numChunksIt = numChunks.find(inputNode); + auto parOptsIt = parOpts.find(inputNode); + if (numChunksIt == numChunks.end() || parOptsIt == parOpts.end()) { + continue; + } + parOpts[R] = parOptsIt->second; + numChunks[R] = numChunksIt->second; + continue; + } + // Split Gelu layers in data parallel fashion if (auto *GL = llvm::dyn_cast(node)) { size_t M = GL->getInput().dims()[0]; diff --git a/lib/Graph/NodeValue.cpp b/lib/Graph/NodeValue.cpp index 75fd33341b..d409858b57 100644 --- a/lib/Graph/NodeValue.cpp +++ b/lib/Graph/NodeValue.cpp @@ -111,6 +111,10 @@ ElemKind NodeValue::getElementType() const { return getType()->getElementType(); } +float NodeValue::getScale() const { return getType()->getScale(); } + +int32_t NodeValue::getOffset() const { return getType()->getOffset(); } + llvm::ArrayRef NodeValue::dims() const { return getType()->dims(); } std::string diff --git a/lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp b/lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp index a93c3f26b9..56f411b98f 100644 --- a/lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp +++ b/lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp @@ -895,20 +895,45 @@ bool SinkCode::run(Function *F, const CompilationContext &cctx) { node->getNthResult(ArithmeticNode::ResultIdx).replaceAllUsesOfWith(newTR); } - // Sink TransposeNode below QuantizedNode. - // If it doesn't work out it will be re-sinked later. if (auto *Q = dyn_cast(node)) { - auto *TR = dyn_cast(Q->getInput()); - if (!TR) { + // Sink TransposeNode below QuantizedNode. + // If it doesn't work out it will be re-sinked later. 
+ if (auto *TR = dyn_cast(Q->getInput())) { + auto newQType = F->getParent()->uniqueTypeWithNewShape( + Q->getResult().getType(), TR->getInput().dims()); + auto *newQ = F->createQuantize(Q->getName(), TR->getInput(), newQType); + auto *newTR = F->createTranspose(TR->getName(), newQ, TR->getShuffle()); + Q->getResult().replaceAllUsesOfWith(newTR); + changed = true; + continue; + } + + // Sink Reshape below Quantize. + if (auto *RN = dyn_cast(Q->getInput())) { + auto newQType = F->getParent()->uniqueTypeWithNewShape( + Q->getResult().getType(), RN->getInput().dims()); + auto *newQ = F->createQuantize(Q->getName(), RN->getInput(), newQType); + auto *newRN = + F->createReshape(RN->getName(), newQ, RN->getResult().dims()); + Q->getResult().replaceAllUsesOfWith(newRN->getResult()); + changed = true; continue; } + } - auto newQType = F->getParent()->uniqueTypeWithNewShape( - Q->getResult().getType(), TR->getInput().dims()); - auto *newQ = F->createQuantize(Q->getName(), TR->getInput(), newQType); - auto *newTR = F->createTranspose(TR->getName(), newQ, TR->getShuffle()); - Q->getResult().replaceAllUsesOfWith(newTR); + // Sink Reshape below ConvertTo. + if (auto *CN = dyn_cast(node)) { + auto *RN = dyn_cast(CN->getInput()); + if (!RN) { + continue; + } + auto *newCN = F->createConvertTo(CN->getName(), RN->getInput(), + CN->getResult().getElementType()); + auto *newRN = + F->createReshape(RN->getName(), newCN, RN->getResult().dims()); + CN->getResult().replaceAllUsesOfWith(newRN->getResult()); changed = true; + continue; } // Sink TransposeNode below DequantizedNode. @@ -4000,28 +4025,47 @@ bool OptimizeConversions::run(Function *F, const CompilationContext &cctx) { return changed; } -/// Optimize Quantize(ConvertTo(Node)) -> Quantize(Node), where Quantize is -/// int8. This may have numerical differences but since Int8 has a small range -/// it's likely fine. This is opt in by a backend. +/// Optimize patterns of Int8 quantization/dequantization with ConvertTo. 
This +/// may have numerical differences but since Int8 has a small range it's likely +/// fine. This is opt in by a backend. bool OptimizeOutIntermediateConversions::run(Function *F, const CompilationContext &cctx) { LOG_SCOPE(F->getLogContext(), getName()); bool changed = false; for (auto &node : F->getNodes()) { - QuantizeNode *QN = llvm::dyn_cast(&node); - if (!QN || - QN->getResult().getType()->getElementType() != ElemKind::Int8QTy) { + // Quantize(ConvertTo(Node)) -> Quantize(Node), where Quantize is int8 + if (QuantizeNode *QN = llvm::dyn_cast(&node)) { + if (QN->getResult().getType()->getElementType() != ElemKind::Int8QTy) { + continue; + } + + ConvertToNode *CN = llvm::dyn_cast(QN->getInput()); + if (!CN) { + continue; + } + + QN->setNthInput(QuantizeNode::InputIdx, CN->getInput()); + changed = true; continue; } - ConvertToNode *CN = llvm::dyn_cast(QN->getInput()); - if (!CN) { + // ConvertTo(Dequantize(Node)) -> Dequantize(Node), where Dequantize is int8 + if (ConvertToNode *CN = llvm::dyn_cast(&node)) { + DequantizeNode *DN = llvm::dyn_cast(CN->getInput()); + if (!DN || + DN->getInput().getType()->getElementType() != ElemKind::Int8QTy) { + continue; + } + + // Create new Dequantize node, dequantizing directly to the kind of the + // ConvertTo that originally consumed it. 
+ DequantizeNode *newDN = F->createDequantize( + DN->getName(), DN->getInput(), CN->getResult().getElementType()); + CN->getResult().replaceAllUsesOfWith(newDN->getResult()); + changed = true; continue; } - - QN->setNthInput(QuantizeNode::InputIdx, CN->getInput()); - changed = true; } return changed; @@ -6320,6 +6364,14 @@ Expected> glow::parallelizeOps( DequantizeNode::ResultIdx, splitDims, 0)); break; } + case Kinded::Kind::RescaleQuantizedNodeKind: { + splitDims[RescaleQuantizedNode::InputIdx] = 0; + ASSIGN_VALUE_OR_RETURN_ERR( + CN, parallelizeAndReplaceNode( + F, curNode, curNumOfChunks, RescaleQuantizedNode::InputIdx, + RescaleQuantizedNode::ResultIdx, splitDims, 0)); + break; + } case Kinded::Kind::ConvertToNodeKind: { splitDims[ConvertToNode::InputIdx] = 0; ASSIGN_VALUE_OR_RETURN_ERR( @@ -6431,6 +6483,19 @@ Expected> glow::parallelizeOps( /*resultDim*/ 1, modelParallelSplitAlignment)); break; } + case Kinded::Kind::RescaleQuantizedNodeKind: { + if (curNode->getNthInput(RescaleQuantizedNode::InputIdx).dims().size() < + 2) { + break; + } + splitDims[RescaleQuantizedNode::InputIdx] = 1; + ASSIGN_VALUE_OR_RETURN_ERR( + CN, parallelizeAndReplaceNode( + F, curNode, curNumOfChunks, RescaleQuantizedNode::InputIdx, + RescaleQuantizedNode::ResultIdx, splitDims, + /*resultDim*/ 1, modelParallelSplitAlignment)); + break; + } default: VLOG(1) << "Attempted to parallelize op type " << curNode->getKindName() << "not yet supported" diff --git a/tests/unittests/GraphOptzTest.cpp b/tests/unittests/GraphOptzTest.cpp index 7d8c197643..85b1d32b64 100644 --- a/tests/unittests/GraphOptzTest.cpp +++ b/tests/unittests/GraphOptzTest.cpp @@ -6697,3 +6697,82 @@ TEST_F(GraphOptz, TestUpdateQuantReluTypesChained) { EXPECT_EQ(qReluTy->getScale(), qReshape->getResult().getType()->getScale()); EXPECT_EQ(qReluTy->getOffset(), qReshape->getResult().getType()->getOffset()); } + +TEST_F(GraphOptz, SinkReshapeBelowQuantize) { + auto *I = mod_.createPlaceholder(ElemKind::FloatTy, {32, 64}, "A", 
false); + auto *RN = F_->createReshape("reshape", I, {32, 64, 1}); + auto *QN = F_->createQuantize("quantize", RN, ElemKind::Int8QTy, 0.2f, 1); + auto *SN = F_->createSave("ret", QN); + + optimizedF_ = optimizeFunctionForTest( + F_, {FunctionPassID::SinkCode, getDCEPassConfig()}); + + auto *optSN = + llvm::dyn_cast(optimizedF_->getNodeByName(SN->getName())); + ASSERT_TRUE(optSN); + auto *optRN = llvm::dyn_cast(optSN->getInput()); + ASSERT_TRUE(optRN); + EXPECT_EQ(optRN->getResult().getElementType(), ElemKind::Int8QTy); + EXPECT_EQ(optRN->getResult().getScale(), 0.2f); + EXPECT_EQ(optRN->getResult().getOffset(), 1); + EXPECT_EQ(optRN->getResult().dims(), RN->getResult().dims()); + auto *optQN = llvm::dyn_cast(optRN->getInput()); + ASSERT_TRUE(optQN); + EXPECT_EQ(optQN->getResult().getElementType(), ElemKind::Int8QTy); + EXPECT_EQ(optQN->getResult().getScale(), 0.2f); + EXPECT_EQ(optQN->getResult().getOffset(), 1); + EXPECT_EQ(optQN->getInput().getNode(), I); + + bindings_.allocate(I)->getHandle().randomize(-30, 30, mod_.getPRNG()); + checkNumericalEquivalence(0.f); +} + +TEST_F(GraphOptz, SinkReshapeBelowConvertTo) { + auto *I = mod_.createPlaceholder(ElemKind::FloatTy, {32, 64}, "A", false); + auto *RN = F_->createReshape("reshape", I, {32, 64, 1}); + auto *CN = F_->createConvertTo("convert", RN, ElemKind::Float16Ty); + auto *SN = F_->createSave("ret", CN); + + optimizedF_ = optimizeFunctionForTest( + F_, {FunctionPassID::SinkCode, getDCEPassConfig()}); + + auto *optSN = + llvm::dyn_cast(optimizedF_->getNodeByName(SN->getName())); + ASSERT_TRUE(optSN); + auto *optRN = llvm::dyn_cast(optSN->getInput()); + ASSERT_TRUE(optRN); + EXPECT_EQ(optRN->getResult().getElementType(), ElemKind::Float16Ty); + EXPECT_EQ(optRN->getResult().dims(), RN->getResult().dims()); + auto *optCN = llvm::dyn_cast(optRN->getInput()); + ASSERT_TRUE(optCN); + EXPECT_EQ(optCN->getResult().getElementType(), ElemKind::Float16Ty); + EXPECT_EQ(optCN->getInput().getNode(), I); + + 
bindings_.allocate(I)->getHandle().randomize(-30, 30, mod_.getPRNG()); + checkNumericalEquivalence(0.f); +} + +TEST_F(GraphOptz, OptConvertToDequantize) { + auto *I = + mod_.createPlaceholder(ElemKind::Int8QTy, {32, 64}, 0.2f, 1, "A", false); + auto *DN = F_->createDequantize("deq", I, ElemKind::Float16Ty); + auto *CN = F_->createConvertTo("convert", DN, ElemKind::FloatTy); + auto *SN = F_->createSave("ret", CN); + + optimizedF_ = optimizeFunctionForTest( + F_, + {FunctionPassID::OptimizeOutIntermediateConversions, getDCEPassConfig()}); + + auto *optSN = + llvm::dyn_cast(optimizedF_->getNodeByName(SN->getName())); + ASSERT_TRUE(optSN); + auto *optDN = llvm::dyn_cast(optSN->getInput()); + ASSERT_TRUE(optDN); + EXPECT_EQ(optDN->getResult().getElementType(), ElemKind::FloatTy); + EXPECT_EQ(optDN->getResult().dims(), DN->getResult().dims()); + EXPECT_EQ(optDN->getInput().getNode(), I); + + bindings_.allocate(I)->getHandle().randomize(-128, 127, + mod_.getPRNG()); + checkNumericalEquivalence(0.007f); +} diff --git a/tests/unittests/NNPIOptPipelineTest.cpp b/tests/unittests/NNPIOptPipelineTest.cpp index f427f095e8..b7c1b7b671 100644 --- a/tests/unittests/NNPIOptPipelineTest.cpp +++ b/tests/unittests/NNPIOptPipelineTest.cpp @@ -753,9 +753,6 @@ TEST_F(NNPIOptPipelineTest, QuantizeFCDequantize) { std::to_string(8); cloneAndCompile(); - F_->dumpDAG("tmp0.dot"); - optimizedF_->dumpDAG("tmp1.dot"); - EXPECT_EQ(countNodeKind(F_, Kinded::Kind::QuantizeNodeKind), 1); EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::QuantizeNodeKind), 8); EXPECT_EQ(countNodeKind(F_, Kinded::Kind::DequantizeNodeKind), 1); @@ -765,6 +762,48 @@ TEST_F(NNPIOptPipelineTest, QuantizeFCDequantize) { 8); } +TEST_F(NNPIOptPipelineTest, ParQuantizeFCReluRescaleSigmoidDequantize) { + auto *input = + mod_.createPlaceholder(ElemKind::Float16Ty, {32, 1024}, "input", false); + auto *weights = F_->getParent()->createConstant( + ElemKind::Int8QTy, {1024, 1024}, 0.2, 0, "weights"); + auto *bias = 
F_->getParent()->createConstant(ElemKind::Int32QTy, {1024}, 0.2, + 0, "bias"); + weights->getPayloadMutable().getHandle().randomize(-127, 127, + mod_.getPRNG()); + bias->getPayloadMutable().getHandle().randomize(-127, 127, + mod_.getPRNG()); + + auto outTy = mod_.uniqueType(ElemKind::Int8QTy, {32, 1024}, 0.2, 0); + auto *quantized = F_->createQuantize("quantize", input, outTy); + auto *FC = F_->createFullyConnected("fc", quantized, weights, bias); + auto *relu = F_->createRELU("relu", FC); + auto rescaleTy = + mod_.uniqueType(ElemKind::Int8QTy, FC->getResult().dims(), 0.15, -1); + auto *rescale = F_->createRescaleQuantized("rescale", relu, rescaleTy); + auto *sigmoid = F_->createSigmoid("sig", rescale); + auto *dequantized = + F_->createDequantize("dequantize", sigmoid, ElemKind::Float16Ty); + F_->createSave("ret", dequantized); + + cctx_.backendOpts.backendSpecificOpts["NNPINumParallelChunks"] = + std::to_string(2); + cloneAndCompile(); + optimizedF_->dumpDAG("tmp1.dot"); + +#define CHECK_PAR(NAME_, PRE_, POST_) \ + EXPECT_EQ(countNodeKind(F_, Kinded::Kind::NAME_##NodeKind), PRE_); \ + EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::NAME_##NodeKind), POST_); + + CHECK_PAR(Quantize, 1, 2); + CHECK_PAR(FullyConnected, 1, 2); + CHECK_PAR(Dequantize, 1, 2); + CHECK_PAR(RescaleQuantized, 1, 2); + CHECK_PAR(Relu, 1, 2); + CHECK_PAR(Sigmoid, 1, 2); +#undef CHECK_PAR +} + // BMM->clip TEST_F(NNPIOptPipelineTest, BMMClip) { auto *input0 =