diff --git a/lib/Graph/TensorLayout.cpp b/lib/Graph/TensorLayout.cpp
index e1e3b6496a..792626c6ef 100644
--- a/lib/Graph/TensorLayout.cpp
+++ b/lib/Graph/TensorLayout.cpp
@@ -623,8 +623,10 @@ static bool acceptsAnyInputLayout(const glow::Node *node) {
   case Kinded::Kind::BatchedReduceMinNodeKind:
   case Kinded::Kind::BatchNormalizationNodeKind:
   case Kinded::Kind::BatchNormalizationGradNodeKind:
+  case Kinded::Kind::PadNodeKind:
   case Kinded::Kind::ReshapeNodeKind:
   case Kinded::Kind::MeanVarNormalizationNodeKind:
+  case Kinded::Kind::MatMulNodeKind:
   case Kinded::Kind::SGDNodeKind: {
     return true;
   }
diff --git a/tests/unittests/TensorLayoutTest.cpp b/tests/unittests/TensorLayoutTest.cpp
index 2306380274..f3c067ffd3 100644
--- a/tests/unittests/TensorLayoutTest.cpp
+++ b/tests/unittests/TensorLayoutTest.cpp
@@ -59,6 +59,24 @@ TEST_P(TensorLayoutTest, convDefault) {
   EXPECT_TRUE(verifyLayouts(*F_, CanonicalTensorLayout::getInstance()));
 }
 
+// Check that pad nodes accept any layout:
+TEST_P(TensorLayoutTest, pad) {
+  CHECK_IF_ENABLED();
+
+  const size_t inputDims[] = {1, 10, 15, 5};
+  const size_t outPadDims[] = {5, 18, 25, 11};
+  int pads[] = {0, 2, 3, 1, 4, 6, 7, 5};
+
+  Node *A = mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false,
+                                   "NCHW");
+  auto outTy = mod_.uniqueType(ElemKind::FloatTy, outPadDims);
+  Node *P = F_->createPad("pad", A, outTy, PaddingMode::CONSTANT, pads, 23.f);
+  SaveNode *S = F_->createSave("save", P);
+  bindings_.allocate(S->getPlaceholder());
+
+  EXPECT_TRUE(verifyLayouts(*F_, CanonicalTensorLayout::getInstance()));
+}
+
 static void buildBadConv(PlaceholderBindings &bindings, Module &mod,
                          Function *F) {
   auto *input = mod.createPlaceholder(ElemKind::FloatTy, {1, 3, 3, 1}, "input",