[GraphOptimizer] Reshape sinking pass #5616

Closed
1 change: 1 addition & 0 deletions include/glow/Optimizer/GraphOptimizer/FunctionPasses.def
@@ -21,6 +21,7 @@
FUN_PASS(DCE)
FUN_PASS(SinkCode)
FUN_PASS(SinkConversions)
FUN_PASS(SinkReshapes)
FUN_PASS(HoistCode)
FUN_PASS(MergeMatMul)
FUN_PASS(MergePadIntoConvolution)
60 changes: 42 additions & 18 deletions lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp
@@ -1067,24 +1067,6 @@ bool SinkCode::run(Function *F, const CompilationContext &cctx) {
continue;
}
}

// Sink Clip below Reshape nodes.
if (auto *RN = dyn_cast<ReshapeNode>(node)) {
auto *CN = dyn_cast<ClipNode>(RN->getInput());
if (!CN) {
continue;
}

ReshapeNode *newRN = F->createReshape(RN->getName(), CN->getInput(),
RN->getDims(), RN->getLayout());
ClipNode *newCN = F->createClip(CN->getName(), newRN->getResult(),
CN->getMin(), CN->getMax());
RN->getResult().replaceAllUsesOfWith(newCN->getResult());
newRN->setPredicate(RN->getPredicate());
newCN->setPredicate(CN->getPredicate());
changed = true;
continue;
}
} // For all nodes in the graph.

// Transformations to sink nodes below Slice. Outlined into a separate loop to
@@ -1153,6 +1135,48 @@ bool HoistCode::run(Function *F, const CompilationContext &cctx) {
return changed;
}

/// Reshape Sinking.
bool SinkReshapes::run(Function *F, const CompilationContext &cctx) {
LOG_SCOPE(F->getLogContext(), getName());
bool changed = false;
auto &nodes = F->getNodes();
// For each node:
for (auto &N : nodes) {
auto *node = &N;

// Sink Reshape below eltwise nodes.
if (node->isDataParallel() && !node->hasSideEffects()) {
// Unary eltwise nodes.
if (node->getNumInputs() == 1 && node->getNumResults() == 1) {
auto *RS = dyn_cast<ReshapeNode>(node->getNthInput(0));
if (!RS) {
continue;
}

// Create new eltwise node.
auto in = RS->getInput();
auto out = node->getNthResult(0);
auto newTy =
F->getParent()->uniqueTypeWithNewShape(out.getType(), in.dims());
auto *newN = F->addNode(node->clone());
newN->setNthInput(0, in);
newN->setTypeUnsafe(0, newTy);
newN->setPredicate(node->getPredicate());

// Create new Reshape.
auto *newRS = F->createReshape(RS->getName(), newN,
RS->getResult().getType()->dims());
newRS->setPredicate(node->getPredicate());
out.replaceAllUsesOfWith(newRS->getResult());

changed = true;
continue;
}
}
}
return changed;
}

/// Remove unnecessary padding and reduce filters for Convolution nodes with
/// small input tensors.
bool OptimizeSmallConv::run(Function *F, const CompilationContext &cctx) {
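For context, here is a minimal standalone sketch (not part of this PR) of the rewrite the new pass performs, built with the same Glow graph-building API used in the tests below; the function name, header paths, and use of the default optimize() entry point are illustrative assumptions:

#include "glow/Graph/Graph.h"
#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"

using namespace glow;

int main() {
  Module mod;
  Function *F = mod.createFunction("sink_reshape_demo");

  // Before: in[10x10] -> Reshape[5x5x4] -> Tanh -> Save.
  auto *in = mod.createPlaceholder(ElemKind::FloatTy, {10, 10}, "in", false);
  auto *RN = F->createReshape("reshape", in, {5, 5, 4});
  auto *TN = F->createTanh("tanh", RN);
  F->createSave("ret", TN);

  // The default pipeline (which now ends with SinkReshapes followed by
  // OptimizeReshape) is expected to rewrite this into
  //   in[10x10] -> Tanh[10x10] -> Reshape[5x5x4] -> Save,
  // i.e. the eltwise node is cloned onto the pre-reshape type and the
  // Reshape is pushed toward the output, which is what the new
  // SinkReshapeBelowUnaryEltwiseOps test verifies.
  optimize(F, CompilationMode::Infer);
  return 0;
}

The in-place retyping via uniqueTypeWithNewShape() and setTypeUnsafe() in the pass is safe for this class of nodes because data-parallel, side-effect-free unary nodes compute element-wise and are indifferent to the tensor's shape.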
8 changes: 8 additions & 0 deletions lib/Optimizer/GraphOptimizerPipeline/FunctionPassPipeline.cpp
@@ -202,6 +202,14 @@ createDefaultGraphOptimizationPassPipeline() {
// Run code hoisting pass to undo such unsuccessful sinking.
{FunctionPassID::HoistCode, ConvergenceMode::UntilFixedPoint},

// Try to eliminate Reshape nodes by sinking them through the graph.
// Such sinking can create new optimization opportunities as well as
// prevent some optimizations from happening, so do it at the very end of
// the pipeline to keep the current iteration unaffected and bear all
// benefits/consequences on the next pipeline iteration.
{FunctionPassID::SinkReshapes, ConvergenceMode::UntilFixedPoint},
{FunctionPassID::OptimizeReshape},

// Perform a round of Dead Code Elimination to cleanup the final pass.
getDCEPassConfig(),
};
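To illustrate the comment above, a hypothetical case (again a sketch, not part of this PR) where sinking enables Reshape elimination; it assumes OptimizeReshape folds back-to-back Reshape nodes and drops identity Reshapes, and reuses the illustrative setup from the sketch above:

#include "glow/Graph/Graph.h"
#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"

using namespace glow;

int main() {
  Module mod;
  Function *F = mod.createFunction("reshape_elim_demo");

  // Before: in[10x10] -> Reshape[100] -> Abs -> Reshape[10x10] -> Save,
  // where the second Reshape undoes the first.
  auto *in = mod.createPlaceholder(ElemKind::FloatTy, {10, 10}, "in", false);
  auto *flat = F->createReshape("flatten", in, {100});
  auto *AN = F->createAbs("abs", flat);
  auto *back = F->createReshape("unflatten", AN, {10, 10});
  F->createSave("ret", back);

  // SinkReshapes pushes the flattening Reshape below the Abs, making the two
  // Reshapes adjacent; OptimizeReshape can then fold them into an identity
  // Reshape and remove it, leaving in -> Abs -> Save with no Reshape at all.
  optimize(F, CompilationMode::Infer);
  return 0;
}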
67 changes: 40 additions & 27 deletions tests/unittests/GraphOptzTest.cpp
@@ -6234,33 +6234,6 @@ TEST_F(GraphOptz, ParallelizeGraph_AvgPool_Model_Axis4) {
checkNumericalEquivalence(0.f);
}

TEST_F(GraphOptz, SinkClipBelowReshape) {
Placeholder *in =
mod_.createPlaceholder(ElemKind::FloatTy, {10}, "input", false);
ClipNode *clip = F_->createClip("clip", in, 0.2, 0.8);
ReshapeNode *reshape = F_->createReshape("reshape", clip, {2, 5});
SaveNode *save = F_->createSave("save", reshape);

optimizedF_ = optimizeFunctionForTest(F_);

// Same number of nodes, just swapped order.
EXPECT_EQ(F_->getNodes().size(), 3);
EXPECT_EQ(optimizedF_->getNodes().size(), 3);

const SaveNode *optSave =
findFunctionNodeByName<SaveNode>(optimizedF_, save->getName());
ASSERT_TRUE(optSave);
ClipNode *newClip = llvm::dyn_cast<ClipNode>(optSave->getInput());
ASSERT_TRUE(newClip);
ReshapeNode *newReshape = llvm::dyn_cast<ReshapeNode>(newClip->getInput());
ASSERT_TRUE(newReshape);
EXPECT_EQ(newReshape->getResult().dims(), reshape->getResult().dims());

bindings_.allocate(mod_.getPlaceholders());
bindings_.get(in)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
checkNumericalEquivalence();
}

/// Test that Add after ConvTranspose is folded into Bias add when the actual
/// Add is a broadcast of the bias. Test \p RnL (right of left) side add.
static void foldConvTransposeAddIntoBiasAdd(PlaceholderBindings &bindings,
@@ -7977,6 +7950,46 @@ TEST_F(GraphOptz, SinkReshapeBelowConvertTo) {
checkNumericalEquivalence(0.f);
}

TEST_F(GraphOptz, SinkReshapeBelowUnaryEltwiseOps) {
const dim_t dimsIn[] = {10, 10};
const dim_t dimsOut[] = {5, 5, 4};

auto *in = mod_.createPlaceholder(ElemKind::FloatTy, dimsIn, "in", false);
auto *RN = F_->createReshape("reshape", in, dimsOut);
auto *AN = F_->createAbs("abs", RN);
auto *SN = F_->createSin("sin", AN);
auto *CN = F_->createClip("clip", SN, -4.f, 5.f);
auto *TN = F_->createTanh("tanh", CN);
auto *save = F_->createSave("ret", TN);

optimizedF_ = optimizeFunctionForTest(F_);

auto *optSave =
llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save->getName()));
ASSERT_TRUE(optSave);
auto *optRN = llvm::dyn_cast<ReshapeNode>(optSave->getInput());
ASSERT_TRUE(optRN);
EXPECT_EQ(optRN->getResult().dims(), llvm::makeArrayRef(dimsOut));
auto *optTN = llvm::dyn_cast<TanhNode>(optRN->getInput());
ASSERT_TRUE(optTN);
EXPECT_EQ(optTN->getResult().dims(), llvm::makeArrayRef(dimsIn));
auto *optCN = llvm::dyn_cast<ClipNode>(optTN->getInput());
ASSERT_TRUE(optCN);
EXPECT_FLOAT_EQ(optCN->getMin(), CN->getMin());
EXPECT_FLOAT_EQ(optCN->getMax(), CN->getMax());
EXPECT_EQ(optCN->getResult().dims(), llvm::makeArrayRef(dimsIn));
auto *optSN = llvm::dyn_cast<SinNode>(optCN->getInput());
ASSERT_TRUE(optSN);
EXPECT_EQ(optSN->getResult().dims(), llvm::makeArrayRef(dimsIn));
auto *optAN = llvm::dyn_cast<AbsNode>(optSN->getInput());
ASSERT_TRUE(optAN);
EXPECT_EQ(optAN->getResult().dims(), llvm::makeArrayRef(dimsIn));

bindings_.allocate(in)->getHandle<float>().randomize(-30.f, 30.f,
mod_.getPRNG());
checkNumericalEquivalence(0.f);
}

TEST_F(GraphOptz, OptConvertToDequantize) {
auto *I =
mod_.createPlaceholder(ElemKind::Int8QTy, {32, 64}, 0.2f, 1, "A", false);
10 changes: 4 additions & 6 deletions tests/unittests/QuantizationTest.cpp
@@ -2721,8 +2721,6 @@ TEST(Quantization, quantizeReshape) {
std::unique_ptr<Backend> backend(createBackend(EE.getBackendName()));
quantization::quantizeFunction(F, quantConfig, *backend);

optimize(F, CompilationMode::Infer);

{
// Verify that the output variable is not quantized, and that it has a
// single save node writer, which is also not quantized.
@@ -2739,16 +2737,16 @@
// Verify that the reshape is rescaled after being quantized.
// The reason we need a rescale is because reshaping doesn't perform
// rescaling by itself.
// Note: after optimization, the RescaleQuantized node created for the
// Reshape gets merged with the dequantize node.
auto *qreshape = llvm::dyn_cast<ReshapeNode>(DN->getInput());
auto *RQ = llvm::dyn_cast<RescaleQuantizedNode>(DN->getInput());
ASSERT_TRUE(RQ);
auto *qreshape = llvm::dyn_cast<ReshapeNode>(RQ->getInput());
ASSERT_TRUE(qreshape);
ASSERT_TRUE(qreshape->getResult().getType()->isQuantizedType());
EXPECT_EQ(qreshape->getResult().getType()->getOffset(),
reshapeInpTQP.offset);
EXPECT_EQ(qreshape->getResult().getType()->getScale(), reshapeInpTQP.scale);

// Verify that the variable inputs to the matmul are quantized.
// Verify that the input to the reshape is quantized.
auto *qinput = llvm::dyn_cast<QuantizeNode>(qreshape->getInput());
ASSERT_TRUE(qinput);
EXPECT_EQ(qinput->getResult().getType()->getOffset(),