Skip to content

Add improved support for parallelization and related graph opts #5257

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/glow/Graph/NodeValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ struct NodeValue {
/// @{
ElemKind getElementType() const;
llvm::ArrayRef<dim_t> dims() const;
float getScale() const;
int32_t getOffset() const;
/// @}

bool operator==(const NodeValue &O) const {
Expand Down
18 changes: 18 additions & 0 deletions lib/Backends/NNPI/NNPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,24 @@ static void setupBasicParallelizationConfigs(
}
}

if (auto *R = llvm::dyn_cast<RescaleQuantizedNode>(node)) {
// For Rescales that are preceded by FC or Relu, mirror their
// parallelization.
Node *inputNode = R->getInput().getNode();
if (!llvm::isa<FullyConnectedNode>(inputNode) &&
!llvm::isa<ReluNode>(inputNode)) {
continue;
}
auto numChunksIt = numChunks.find(inputNode);
auto parOptsIt = parOpts.find(inputNode);
if (numChunksIt == numChunks.end() || parOptsIt == parOpts.end()) {
continue;
}
parOpts[R] = parOptsIt->second;
numChunks[R] = numChunksIt->second;
continue;
}

// Split Gelu layers in data parallel fashion
if (auto *GL = llvm::dyn_cast<GeluNode>(node)) {
size_t M = GL->getInput().dims()[0];
Expand Down
4 changes: 4 additions & 0 deletions lib/Graph/NodeValue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ ElemKind NodeValue::getElementType() const {
return getType()->getElementType();
}

/// \returns the quantization scale of the underlying type. Valid only for
/// quantized types (getType()->getScale() asserts otherwise).
float NodeValue::getScale() const { return getType()->getScale(); }

/// \returns the quantization offset (zero point) of the underlying type. Valid
/// only for quantized types (getType()->getOffset() asserts otherwise).
int32_t NodeValue::getOffset() const { return getType()->getOffset(); }

llvm::ArrayRef<dim_t> NodeValue::dims() const { return getType()->dims(); }

std::string
Expand Down
105 changes: 85 additions & 20 deletions lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -895,20 +895,45 @@ bool SinkCode::run(Function *F, const CompilationContext &cctx) {
node->getNthResult(ArithmeticNode::ResultIdx).replaceAllUsesOfWith(newTR);
}

// Sink TransposeNode below QuantizedNode.
// If it doesn't work out it will be re-sinked later.
if (auto *Q = dyn_cast<QuantizeNode>(node)) {
auto *TR = dyn_cast<TransposeNode>(Q->getInput());
if (!TR) {
// Sink TransposeNode below QuantizedNode.
// If it doesn't work out it will be re-sinked later.
if (auto *TR = dyn_cast<TransposeNode>(Q->getInput())) {
auto newQType = F->getParent()->uniqueTypeWithNewShape(
Q->getResult().getType(), TR->getInput().dims());
auto *newQ = F->createQuantize(Q->getName(), TR->getInput(), newQType);
auto *newTR = F->createTranspose(TR->getName(), newQ, TR->getShuffle());
Q->getResult().replaceAllUsesOfWith(newTR);
changed = true;
continue;
}

// Sink Reshape below Quantize.
if (auto *RN = dyn_cast<ReshapeNode>(Q->getInput())) {
auto newQType = F->getParent()->uniqueTypeWithNewShape(
Q->getResult().getType(), RN->getInput().dims());
auto *newQ = F->createQuantize(Q->getName(), RN->getInput(), newQType);
auto *newRN =
F->createReshape(RN->getName(), newQ, RN->getResult().dims());
Q->getResult().replaceAllUsesOfWith(newRN->getResult());
changed = true;
continue;
}
}

auto newQType = F->getParent()->uniqueTypeWithNewShape(
Q->getResult().getType(), TR->getInput().dims());
auto *newQ = F->createQuantize(Q->getName(), TR->getInput(), newQType);
auto *newTR = F->createTranspose(TR->getName(), newQ, TR->getShuffle());
Q->getResult().replaceAllUsesOfWith(newTR);
// Sink Reshape below ConvertTo.
if (auto *CN = dyn_cast<ConvertToNode>(node)) {
auto *RN = dyn_cast<ReshapeNode>(CN->getInput());
if (!RN) {
continue;
}
auto *newCN = F->createConvertTo(CN->getName(), RN->getInput(),
CN->getResult().getElementType());
auto *newRN =
F->createReshape(RN->getName(), newCN, RN->getResult().dims());
CN->getResult().replaceAllUsesOfWith(newRN->getResult());
changed = true;
continue;
}

// Sink TransposeNode below DequantizedNode.
Expand Down Expand Up @@ -4000,28 +4025,47 @@ bool OptimizeConversions::run(Function *F, const CompilationContext &cctx) {
return changed;
}

/// Optimize Quantize(ConvertTo(Node)) -> Quantize(Node), where Quantize is
/// int8. This may have numerical differences but since Int8 has a small range
/// it's likely fine. This is opt in by a backend.
/// Optimize patterns of Int8 quantization/dequantization with ConvertTo. This
/// may have numerical differences but since Int8 has a small range it's likely
/// fine. This is opt in by a backend.
bool OptimizeOutIntermediateConversions::run(Function *F,
const CompilationContext &cctx) {
LOG_SCOPE(F->getLogContext(), getName());

bool changed = false;
for (auto &node : F->getNodes()) {
QuantizeNode *QN = llvm::dyn_cast<QuantizeNode>(&node);
if (!QN ||
QN->getResult().getType()->getElementType() != ElemKind::Int8QTy) {
// Quantize(ConvertTo(Node)) -> Quantize(Node), where Quantize is int8
if (QuantizeNode *QN = llvm::dyn_cast<QuantizeNode>(&node)) {
if (QN->getResult().getType()->getElementType() != ElemKind::Int8QTy) {
continue;
}

ConvertToNode *CN = llvm::dyn_cast<ConvertToNode>(QN->getInput());
if (!CN) {
continue;
}

QN->setNthInput(QuantizeNode::InputIdx, CN->getInput());
changed = true;
continue;
}

ConvertToNode *CN = llvm::dyn_cast<ConvertToNode>(QN->getInput());
if (!CN) {
// ConvertTo(Dequantize(Node)) -> Dequantize(Node), where Dequantize is int8
if (ConvertToNode *CN = llvm::dyn_cast<ConvertToNode>(&node)) {
DequantizeNode *DN = llvm::dyn_cast<DequantizeNode>(CN->getInput());
if (!DN ||
DN->getInput().getType()->getElementType() != ElemKind::Int8QTy) {
continue;
}

// Create new Dequantize node, dequantizing directly to the kind of the
// ConvertTo that originally consumed it.
DequantizeNode *newDN = F->createDequantize(
DN->getName(), DN->getInput(), CN->getResult().getElementType());
CN->getResult().replaceAllUsesOfWith(newDN->getResult());
changed = true;
continue;
}

QN->setNthInput(QuantizeNode::InputIdx, CN->getInput());
changed = true;
}

return changed;
Expand Down Expand Up @@ -6320,6 +6364,14 @@ Expected<std::unordered_map<Node *, ConcatNode *>> glow::parallelizeOps(
DequantizeNode::ResultIdx, splitDims, 0));
break;
}
case Kinded::Kind::RescaleQuantizedNodeKind: {
splitDims[RescaleQuantizedNode::InputIdx] = 0;
ASSIGN_VALUE_OR_RETURN_ERR(
CN, parallelizeAndReplaceNode(
F, curNode, curNumOfChunks, RescaleQuantizedNode::InputIdx,
RescaleQuantizedNode::ResultIdx, splitDims, 0));
break;
}
case Kinded::Kind::ConvertToNodeKind: {
splitDims[ConvertToNode::InputIdx] = 0;
ASSIGN_VALUE_OR_RETURN_ERR(
Expand Down Expand Up @@ -6431,6 +6483,19 @@ Expected<std::unordered_map<Node *, ConcatNode *>> glow::parallelizeOps(
/*resultDim*/ 1, modelParallelSplitAlignment));
break;
}
case Kinded::Kind::RescaleQuantizedNodeKind: {
if (curNode->getNthInput(RescaleQuantizedNode::InputIdx).dims().size() <
2) {
break;
}
splitDims[RescaleQuantizedNode::InputIdx] = 1;
ASSIGN_VALUE_OR_RETURN_ERR(
CN, parallelizeAndReplaceNode(
F, curNode, curNumOfChunks, RescaleQuantizedNode::InputIdx,
RescaleQuantizedNode::ResultIdx, splitDims,
/*resultDim*/ 1, modelParallelSplitAlignment));
break;
}
default:
VLOG(1) << "Attempted to parallelize op type " << curNode->getKindName()
<< "not yet supported"
Expand Down
79 changes: 79 additions & 0 deletions tests/unittests/GraphOptzTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6697,3 +6697,82 @@ TEST_F(GraphOptz, TestUpdateQuantReluTypesChained) {
EXPECT_EQ(qReluTy->getScale(), qReshape->getResult().getType()->getScale());
EXPECT_EQ(qReluTy->getOffset(), qReshape->getResult().getType()->getOffset());
}

/// Check that SinkCode hoists a Quantize above a preceding Reshape: after the
/// pass the Quantize consumes the placeholder directly and the Reshape becomes
/// the final, already-quantized node, with scale/offset preserved.
TEST_F(GraphOptz, SinkReshapeBelowQuantize) {
  auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {32, 64}, "A", false);
  auto *reshape = F_->createReshape("reshape", input, {32, 64, 1});
  auto *quant =
      F_->createQuantize("quantize", reshape, ElemKind::Int8QTy, 0.2f, 1);
  auto *save = F_->createSave("ret", quant);

  optimizedF_ = optimizeFunctionForTest(
      F_, {FunctionPassID::SinkCode, getDCEPassConfig()});

  // Expected optimized chain: Save <- Reshape(int8) <- Quantize <- Placeholder.
  auto *optSave =
      llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save->getName()));
  ASSERT_TRUE(optSave);
  auto *optReshape = llvm::dyn_cast<ReshapeNode>(optSave->getInput());
  ASSERT_TRUE(optReshape);
  EXPECT_EQ(optReshape->getResult().getElementType(), ElemKind::Int8QTy);
  EXPECT_EQ(optReshape->getResult().getScale(), 0.2f);
  EXPECT_EQ(optReshape->getResult().getOffset(), 1);
  EXPECT_EQ(optReshape->getResult().dims(), reshape->getResult().dims());
  auto *optQuant = llvm::dyn_cast<QuantizeNode>(optReshape->getInput());
  ASSERT_TRUE(optQuant);
  EXPECT_EQ(optQuant->getResult().getElementType(), ElemKind::Int8QTy);
  EXPECT_EQ(optQuant->getResult().getScale(), 0.2f);
  EXPECT_EQ(optQuant->getResult().getOffset(), 1);
  EXPECT_EQ(optQuant->getInput().getNode(), input);

  // Both function versions must compute the same result bit-for-bit.
  bindings_.allocate(input)->getHandle<float>().randomize(-30, 30,
                                                          mod_.getPRNG());
  checkNumericalEquivalence(0.f);
}

/// Check that SinkCode hoists a ConvertTo above a preceding Reshape: after the
/// pass the conversion happens on the placeholder's original shape and the
/// Reshape is the final node, already in the converted element type.
TEST_F(GraphOptz, SinkReshapeBelowConvertTo) {
  auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {32, 64}, "A", false);
  auto *reshape = F_->createReshape("reshape", input, {32, 64, 1});
  auto *convert = F_->createConvertTo("convert", reshape, ElemKind::Float16Ty);
  auto *save = F_->createSave("ret", convert);

  optimizedF_ = optimizeFunctionForTest(
      F_, {FunctionPassID::SinkCode, getDCEPassConfig()});

  // Expected optimized chain: Save <- Reshape(fp16) <- ConvertTo <- Placeholder.
  auto *optSave =
      llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save->getName()));
  ASSERT_TRUE(optSave);
  auto *optReshape = llvm::dyn_cast<ReshapeNode>(optSave->getInput());
  ASSERT_TRUE(optReshape);
  EXPECT_EQ(optReshape->getResult().getElementType(), ElemKind::Float16Ty);
  EXPECT_EQ(optReshape->getResult().dims(), reshape->getResult().dims());
  auto *optConvert = llvm::dyn_cast<ConvertToNode>(optReshape->getInput());
  ASSERT_TRUE(optConvert);
  EXPECT_EQ(optConvert->getResult().getElementType(), ElemKind::Float16Ty);
  EXPECT_EQ(optConvert->getInput().getNode(), input);

  // Both function versions must compute the same result bit-for-bit.
  bindings_.allocate(input)->getHandle<float>().randomize(-30, 30,
                                                          mod_.getPRNG());
  checkNumericalEquivalence(0.f);
}

/// Check that OptimizeOutIntermediateConversions folds
/// ConvertTo(Dequantize(int8)) into a single Dequantize that produces the
/// ConvertTo's element type directly.
TEST_F(GraphOptz, OptConvertToDequantize) {
  auto *input =
      mod_.createPlaceholder(ElemKind::Int8QTy, {32, 64}, 0.2f, 1, "A", false);
  auto *dequant = F_->createDequantize("deq", input, ElemKind::Float16Ty);
  auto *convert = F_->createConvertTo("convert", dequant, ElemKind::FloatTy);
  auto *save = F_->createSave("ret", convert);

  optimizedF_ = optimizeFunctionForTest(
      F_,
      {FunctionPassID::OptimizeOutIntermediateConversions, getDCEPassConfig()});

  // Expected optimized chain: Save <- Dequantize(fp32) <- Placeholder; the
  // intermediate fp16 hop is gone.
  auto *optSave =
      llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save->getName()));
  ASSERT_TRUE(optSave);
  auto *optDequant = llvm::dyn_cast<DequantizeNode>(optSave->getInput());
  ASSERT_TRUE(optDequant);
  EXPECT_EQ(optDequant->getResult().getElementType(), ElemKind::FloatTy);
  EXPECT_EQ(optDequant->getResult().dims(), dequant->getResult().dims());
  EXPECT_EQ(optDequant->getInput().getNode(), input);

  // Skipping the fp16 intermediate changes rounding, so allow a small delta.
  bindings_.allocate(input)->getHandle<int8_t>().randomize(-128, 127,
                                                           mod_.getPRNG());
  checkNumericalEquivalence(0.007f);
}
45 changes: 42 additions & 3 deletions tests/unittests/NNPIOptPipelineTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -753,9 +753,6 @@ TEST_F(NNPIOptPipelineTest, QuantizeFCDequantize) {
std::to_string(8);
cloneAndCompile();

F_->dumpDAG("tmp0.dot");
optimizedF_->dumpDAG("tmp1.dot");

EXPECT_EQ(countNodeKind(F_, Kinded::Kind::QuantizeNodeKind), 1);
EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::QuantizeNodeKind), 8);
EXPECT_EQ(countNodeKind(F_, Kinded::Kind::DequantizeNodeKind), 1);
Expand All @@ -765,6 +762,48 @@ TEST_F(NNPIOptPipelineTest, QuantizeFCDequantize) {
8);
}

/// Parallelization of a Quantize -> FC -> Relu -> RescaleQuantized -> Sigmoid
/// -> Dequantize chain. The RescaleQuantized must mirror the parallelization
/// of the FC/Relu it follows, so every op in the chain is split into the
/// requested number of chunks.
TEST_F(NNPIOptPipelineTest, ParQuantizeFCReluRescaleSigmoidDequantize) {
  auto *input =
      mod_.createPlaceholder(ElemKind::Float16Ty, {32, 1024}, "input", false);
  auto *weights = F_->getParent()->createConstant(
      ElemKind::Int8QTy, {1024, 1024}, 0.2, 0, "weights");
  auto *bias = F_->getParent()->createConstant(ElemKind::Int32QTy, {1024}, 0.2,
                                               0, "bias");
  weights->getPayloadMutable().getHandle<int8_t>().randomize(-127, 127,
                                                             mod_.getPRNG());
  bias->getPayloadMutable().getHandle<int32_t>().randomize(-127, 127,
                                                           mod_.getPRNG());

  auto outTy = mod_.uniqueType(ElemKind::Int8QTy, {32, 1024}, 0.2, 0);
  auto *quantized = F_->createQuantize("quantize", input, outTy);
  auto *FC = F_->createFullyConnected("fc", quantized, weights, bias);
  auto *relu = F_->createRELU("relu", FC);
  // Rescale to a different scale/offset so the node is not a no-op.
  auto rescaleTy =
      mod_.uniqueType(ElemKind::Int8QTy, FC->getResult().dims(), 0.15, -1);
  auto *rescale = F_->createRescaleQuantized("rescale", relu, rescaleTy);
  auto *sigmoid = F_->createSigmoid("sig", rescale);
  auto *dequantized =
      F_->createDequantize("dequantize", sigmoid, ElemKind::Float16Ty);
  F_->createSave("ret", dequantized);

  cctx_.backendOpts.backendSpecificOpts["NNPINumParallelChunks"] =
      std::to_string(2);
  cloneAndCompile();

// Each node kind appears once pre-opt (PRE_) and, after parallelization into
// 2 chunks, POST_ times in the optimized function.
#define CHECK_PAR(NAME_, PRE_, POST_)                                          \
  EXPECT_EQ(countNodeKind(F_, Kinded::Kind::NAME_##NodeKind), PRE_);           \
  EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::NAME_##NodeKind), POST_);

  CHECK_PAR(Quantize, 1, 2);
  CHECK_PAR(FullyConnected, 1, 2);
  CHECK_PAR(Dequantize, 1, 2);
  CHECK_PAR(RescaleQuantized, 1, 2);
  CHECK_PAR(Relu, 1, 2);
  CHECK_PAR(Sigmoid, 1, 2);
#undef CHECK_PAR
}

// BMM->clip
TEST_F(NNPIOptPipelineTest, BMMClip) {
auto *input0 =
Expand Down