
Commit 134e688

[TensorLayout] Propagate the input layout requirements for convertTo nodes
1 parent 973eb21 commit 134e688

File tree

5 files changed: +62 −1 lines changed


docs/Backends.md (6 additions, 0 deletions)

````diff
@@ -195,6 +195,12 @@ BB.newBackendSpecificNode("CPUMaxSplat")
     .setDocstring("A Max node with one splat input; CPU specific.");
 ```
 
+If tensor layout requirements are enabled for the backend, one should take
+special care to update the layout verifier when adding a new node.
+See `TensorLayout.md` for more information.
+To extend the example above, if the new node is data parallel, a `.dataParallel()`
+line should be added.
+
 During `transformPostLowering()`, this `CPUMaxSplat` node replaces the
 aforementioned pattern. However, there must be a corresponding instruction for
 this Node to be lowered to during the IRGen phase. Thus, we need a corresponding
````
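For context on the `.dataParallel()` line the documentation change mentions, a ClassGen registration might look like the sketch below. This is an illustrative guess at the `CPUMaxSplat` definition, extrapolated from the `.setDocstring(...)` snippet quoted in the diff; the exact builder calls (`addInput`, `addMember`, `addResultFromCtorArg`) may differ from Glow's actual `NodeGen.cpp`:

```
// Hypothetical ClassGen sketch: marking a backend-specific node as data
// parallel, so the layout verifier accepts any input layout for it.
BB.newBackendSpecificNode("CPUMaxSplat")
    .addInput("Input")
    .addMember(MemberType::Float, "SplatValue")
    .addResultFromCtorArg()
    .dataParallel() // element-wise: imposes no layout constraints
    .setDocstring("A Max node with one splat input; CPU specific.");
```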

docs/NewOperators.md (2 additions, 0 deletions)

```diff
@@ -8,6 +8,8 @@
 #### High level IR
 * Create a new Glow high level IR node in `ClassGen/NodeGen.cpp`. Run `ninja all` to generate the node. In the build directory, check `glow/AutoGenNodes.h` to ensure the node has been generated.
 * Implement the `verify()` method for the new node in `Graph/Nodes.cpp`.
+* Implement the node's layout requirements, if any; see `TensorLayout.md` for details.
+  Specifically, see the notes section under `Canonical Tensor Layout`.
 * Implement a node creation method in `Graph/Graph.cpp`.
 * Implement logic to load model that contains the operator in `Importer/Caffe2ModelLoader.cpp` or `Importer/ONNXModelLoader.cpp` depending on which type of model the operator comes from. Add the operator to `Importer/CommonOperatorLoader.h` instead if the loading logic can be shared between Caffe2 and ONNX. Add as much validation logic as possible here in the loader for the operator because it's crucial to catch errors at this stage. Once the operator is loaded, it is assumed that Glow will be able to successfully run the operator so any issues must be caught here.
 #### Low level IR
```

docs/TensorLayout.md (10 additions, 0 deletions)

```diff
@@ -108,6 +108,16 @@ derives from `TensorLayoutCommon` and overrides the following functions:
 - This function takes an operator `Node *node` and returns the layout requirements of the Nth result `n`.
 - It returns Common layout constraints, for example, `ConvolutionNode` should be in `NHWC` format.
 
+Notes:
+
+1. Some nodes can accept any layout as input: they are either data parallel, e.g. `Add`,
+or, while not data parallel, do not care about the order of dimensions for their operation,
+e.g. `ReshapeNodeKind`. When adding new nodes to Glow, such behavior should be explicitly
+specified, for example by adding `.dataParallel()` in NodeGen.
+2. Some nodes propagate the layout information of their input, e.g. the `ConvertTo` node;
+when adding such nodes to Glow, the canonical layout verifier should be made aware of them.
+We currently do that in `getNthInputLayoutRequirements`.
+
 ## Placeholders and Constants
 
 An important thing to note is that some operators may have a `Placeholder` or
```

lib/Graph/TensorLayout.cpp (22 additions, 1 deletion)

```diff
@@ -455,6 +455,10 @@ std::string TensorLayoutCommon::getNthInputLayoutRequirements(const Node *node,
     auto input = QN->getInput();
     return getNthResultLayoutRequirements(input.getNode(), input.getResNo());
   }
+  if (const auto *CTN = llvm::dyn_cast<ConvertToNode>(node)) {
+    auto input = CTN->getInput();
+    return getNthResultLayoutRequirements(input.getNode(), input.getResNo());
+  }
   if (const auto *QPN = llvm::dyn_cast<QuantizationProfileNode>(node)) {
     switch (n) {
     case QuantizationProfileNode::InputIndices::InputIdx: {
@@ -478,6 +482,19 @@ static unsigned getInputIdx(const Node *N, NodeValue in) {
   return N->getNumInputs();
 }
 
+/// \returns true if getting the input's layout would cause an infinite loop.
+static bool inputDoesNotKnowRequirements(const Node *node) {
+  switch (node->getKind()) {
+  case Kinded::Kind::TransposeNodeKind:
+  case Kinded::Kind::QuantizeNodeKind:
+  case Kinded::Kind::QuantizationProfileNodeKind:
+  case Kinded::Kind::ConvertToNodeKind:
+    return true;
+  default:
+    return false;
+  }
+}
+
 std::string TensorLayoutCommon::getNthResultLayoutRequirements(const Node *node,
                                                                size_t n) {
   DCHECK_LT(n, node->getNumResults()) << "Wrong output number";
@@ -492,6 +509,9 @@ std::string TensorLayoutCommon::getNthResultLayoutRequirements(const Node *node,
   }
   // Dynamically form the layout description for transposes.
   auto input = TN->getInput();
+  while (inputDoesNotKnowRequirements(input)) {
+    input = input.getNode()->getNthInput(0);
+  }
   auto inputLayout =
       getNthInputLayoutRequirements(node, TransposeNode::InputIdx);
   auto inputLayoutHelper = TensorLayoutDescription(inputLayout);
@@ -524,7 +544,8 @@ std::string TensorLayoutCommon::getNthResultLayoutRequirements(const Node *node,
   auto result = node->getNthResult(n);
   auto *user = (*result.getUsers().begin()).getUser();
   int inputIdx = getInputIdx(user, result);
-  if (inputIdx >= user->getNumInputs() || llvm::isa<TransposeNode>(user)) {
+  if (inputDoesNotKnowRequirements(user) ||
+      inputIdx >= user->getNumInputs() || llvm::isa<TransposeNode>(user)) {
     return getLayoutsForDims()[dims.size()].getSerializedLayout();
   }
   auto layout = getNthInputLayoutRequirements(user, inputIdx);
```

tests/unittests/TensorLayoutTest.cpp (22 additions, 0 deletions)

```diff
@@ -16,6 +16,8 @@
 #include "BackendTestUtils.h"
 
 #include "glow/Backend/Backend.h"
+#include "glow/Converter/Float16Converter.h"
+#include "glow/Converter/TypeAToTypeBFunctionConverter.h"
 #include "glow/Graph/Graph.h"
 #include "glow/Graph/TensorLayout.h"
 #include "llvm/Support/raw_ostream.h"
@@ -91,6 +93,26 @@ TEST_P(TensorLayoutTest, convBadLayout) {
   EXPECT_FALSE(verifyLayouts(*F_, CanonicalTensorLayout::getInstance(), false));
 }
 
+// Check that we propagate the layout information for convertTo nodes:
+TEST_P(TensorLayoutTest, convertTo) {
+  CHECK_IF_ENABLED();
+
+  auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 3, 3, 1}, "input",
+                                       false, "NWCH");
+  auto *resultNCHW = F_->createTranspose("transposeInput", input, NHWC2NCHW);
+  auto *save = F_->createSave("save", resultNCHW);
+  bindings_.allocate(save->getPlaceholder());
+
+  EXPECT_TRUE(verifyLayouts(*F_, CanonicalTensorLayout::getInstance()));
+
+  PrecisionConfiguration precConfig;
+  TypeAToTypeBFunctionConverter converter(*F_, ElemKind::FloatTy,
+                                          ElemKind::Float16Ty, precConfig);
+  converter.convert();
+
+  EXPECT_TRUE(verifyLayouts(*F_, CanonicalTensorLayout::getInstance()));
+}
+
 // Check TensorLayoutDescription's parser with simple input.
 TEST_P(TensorLayoutTest, parseTestSimple) {
   CHECK_IF_ENABLED();
```
