@@ -2323,21 +2323,23 @@ TEST_F(GraphOptz, ReshapeAfterSplat) {
2323
2323
const dim_t reshape[] = {1, 6000};
2324
2324
Type t1(ElemKind::FloatTy, shape);
2325
2325
Type t2(ElemKind::FloatTy, reshape);
2326
- Node *input = F_->getParent()->createPlaceholder(ElemKind::FloatTy, shape,
2327
- "input", true);
2326
+ Node *input1 = F_->getParent()->createPlaceholder(ElemKind::FloatTy, shape,
2327
+ "input1", true);
2328
+ Node *input2 = F_->getParent()->createPlaceholder(ElemKind::FloatTy, reshape,
2329
+ "input2", true);
2328
2330
auto *Z1 = F_->createSplat("zero1", &t1, 1.5);
2329
- auto *A1 = F_->createAdd("add1", Z1->getResult().getType(), input , Z1);
2331
+ auto *A1 = F_->createAdd("add1", Z1->getResult().getType(), input1 , Z1);
2330
2332
auto *R1 = F_->createReshape("reshape1", Z1, reshape);
2331
2333
// Z1 is used by R1 and A1.
2332
2334
// The reshape optimization will thus NOT be able to remove this reshape node
2333
2335
// (R1).
2334
- auto *R2 = F_->createReshape("reshape2 ", A1, reshape );
2335
- auto *A2 = F_->createAdd("add", R1->getResult().getType(), R1, R2 );
2336
+ F_->createSave("save ", A1);
2337
+ auto *A2 = F_->createAdd("add", R1->getResult().getType(), R1, input2 );
2336
2338
auto *Z2 = F_->createSplat("zero2", &t1, 2.5);
2337
- auto *R3 = F_->createReshape("reshape3", Z2, reshape);
2338
- // Z2 is only used by R3 .
2339
- // The Z2,R3 nodes will be replaced by a new splat node with the shape of R3 .
2340
- auto *A3 = F_->createAdd("add", A2->getResult().getType(), A2, R3 );
2339
+ auto *R2 = F_->createReshape("reshape3", Z2, reshape);
2340
+ // Z2 is only used by R2 .
2341
+ // The Z2,R2 nodes will be replaced by a new splat node with the shape of R2 .
2342
+ auto *A3 = F_->createAdd("add", A2->getResult().getType(), A2, R2 );
2341
2343
auto *O = F_->createSave("ret", A3);
2342
2344
2343
2345
// Before optimization, we have 9 nodes in the graph.
@@ -2352,7 +2354,7 @@ TEST_F(GraphOptz, ReshapeAfterSplat) {
2352
2354
// replace by a new splat node.
2353
2355
EXPECT_EQ(F_->getNodes().size(), 8);
2354
2356
2355
- // The second input of A3 shoule be a splat node with a shape of R3 .
2357
+ // The second input of A3 shoule be a splat node with a shape of R2 .
2356
2358
auto *newA3 = llvm::dyn_cast<AddNode>(O->getInput());
2357
2359
ASSERT_TRUE(newA3);
2358
2360
auto *SN = llvm::dyn_cast<SplatNode>(newA3->getRHS());
@@ -2362,8 +2364,8 @@ TEST_F(GraphOptz, ReshapeAfterSplat) {
2362
2364
// R1 should still be in the graph.
2363
2365
EXPECT_TRUE(functionContainsNode(F_, R1));
2364
2366
2365
- // R3 and Z2 should not be in the graph any more.
2366
- EXPECT_FALSE(functionContainsNode(F_, R3 ));
2367
+ // R2 and Z2 should not be in the graph any more.
2368
+ EXPECT_FALSE(functionContainsNode(F_, R2 ));
2367
2369
EXPECT_FALSE(functionContainsNode(F_, Z2));
2368
2370
}
2369
2371
@@ -8043,6 +8045,62 @@ TEST_F(GraphOptz, SinkReshapeBelowUnaryEltwiseOps) {
8043
8045
checkNumericalEquivalence(0.f);
8044
8046
}
8045
8047
8048
+ TEST_F(GraphOptz, SinkReshapeBelowBinaryEltwiseOps) {
8049
+ const dim_t dimsIn[] = {10, 10};
8050
+ const dim_t dimsOut[] = {5, 5, 4};
8051
+
8052
+ // Prepare inputs.
8053
+ auto *in1 = mod_.createPlaceholder(glow::ElemKind::Int8QTy, dimsIn, 0.12f, 0,
8054
+ "in1", false);
8055
+ auto *in2 = mod_.createPlaceholder(glow::ElemKind::Int8QTy, dimsIn, 0.17f, 0,
8056
+ "in2", false);
8057
+ auto *QCN =
8058
+ mod_.createConstant(ElemKind::Int8QTy, dimsOut, 0.13f, 0, "quant_const");
8059
+ auto *FCN = mod_.createConstant(ElemKind::FloatTy, dimsOut, "float_const");
8060
+ auto qTy = mod_.uniqueType(ElemKind::Int8QTy, dimsOut, 0.15f, 0);
8061
+ auto *QN = F_->createQuantize("quantize", FCN, qTy);
8062
+ auto *SN = F_->createSplat("splat", qTy, 1.79f);
8063
+ QCN->getHandle<int8_t>().randomize(-128, 127, mod_.getPRNG());
8064
+ FCN->getHandle<float>().randomize(-1.f, 2.f, mod_.getPRNG());
8065
+
8066
+ // Test different combinations of Reshape, Constant, Quantize, Splat passed as
8067
+ // LHS or RHS.
8068
+ auto *RN1 = F_->createReshape("reshape", in1, dimsOut);
8069
+ auto *RN2 = F_->createReshape("reshape", in2, dimsOut);
8070
+ auto *AN = F_->createAdd("add", RN1, RN2);
8071
+ auto *MLN = F_->createMul("mul", AN, QCN);
8072
+ auto *MXN = F_->createMax("max", QN, MLN);
8073
+ auto *SBN = F_->createSub("sub", MXN, SN);
8074
+ auto *save = F_->createSave("ret", SBN);
8075
+
8076
+ optimizedF_ = optimizeFunctionForTest(F_);
8077
+
8078
+ auto *optSave =
8079
+ llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save->getName()));
8080
+ ASSERT_TRUE(optSave);
8081
+ auto *optRN = llvm::dyn_cast<ReshapeNode>(optSave->getInput());
8082
+ ASSERT_TRUE(optRN);
8083
+ EXPECT_EQ(optRN->getResult().dims(), llvm::makeArrayRef(dimsOut));
8084
+ auto *optSBN = llvm::dyn_cast<SubNode>(optRN->getInput());
8085
+ ASSERT_TRUE(optSBN);
8086
+ EXPECT_EQ(optSBN->getResult().dims(), llvm::makeArrayRef(dimsIn));
8087
+ auto *optMXN = llvm::dyn_cast<MaxNode>(optSBN->getLHS());
8088
+ ASSERT_TRUE(optMXN);
8089
+ EXPECT_EQ(optMXN->getResult().dims(), llvm::makeArrayRef(dimsIn));
8090
+ auto *optMLN = llvm::dyn_cast<MulNode>(optMXN->getRHS());
8091
+ ASSERT_TRUE(optMLN);
8092
+ EXPECT_EQ(optMLN->getResult().dims(), llvm::makeArrayRef(dimsIn));
8093
+ auto *optAN = llvm::dyn_cast<AddNode>(optMLN->getLHS());
8094
+ ASSERT_TRUE(optAN);
8095
+ EXPECT_EQ(optAN->getResult().dims(), llvm::makeArrayRef(dimsIn));
8096
+
8097
+ bindings_.allocate(in1)->getHandle<int8_t>().randomize(-128, 127,
8098
+ mod_.getPRNG());
8099
+ bindings_.allocate(in2)->getHandle<int8_t>().randomize(-128, 127,
8100
+ mod_.getPRNG());
8101
+ checkNumericalEquivalence(0.f);
8102
+ }
8103
+
8046
8104
TEST_F(GraphOptz, OptConvertToDequantize) {
8047
8105
auto *I =
8048
8106
mod_.createPlaceholder(ElemKind::Int8QTy, {32, 64}, 0.2f, 1, "A", false);
0 commit comments