
Commit 1b7ecf2

Tensor Layouts: update PR based on Jordan's review
1 parent 284d165 commit 1b7ecf2

7 files changed: +20 additions, −17 deletions

docs/TensorLayout.md

Lines changed: 3 additions & 2 deletions
```diff
@@ -25,7 +25,7 @@ Glow's string-based layout format is encoded as follows:
 
 1. A mandatory one character representing the current dimension. Either an alphabetic letter or `*` (any layout).
 2. An optional token for the start of the current dimension's information: `[`.
-3. An optional namespace identifier for non-standard information, such as tiling, followed by `:`. Must have `[` from 2. in place. following said identifier, all subsequent data is considered as a "black box" until `]` is encountered.
+3. An optional namespace identifier for non-standard information, such as tiling, followed by `:`. Must have `[` from 2. in place. Following said identifier, all subsequent data is considered as a "black box" until `]` is encountered.
 4. Given that we have `[` from 2. in place, the closing bracket `]` for it.
 5. Optionally go back to 2.
 
```
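A few illustrative strings may make the grammar above concrete. These exact strings, including the "tile" namespace and its payload, are made up for illustration; only the rules themselves come from the doc:

```cpp
// Illustrative layout strings under the grammar above (hypothetical values).
const char *image4D = "NHWC";         // Rule 1 only: one letter per dimension.
const char *any4D = "****";           // Four dimensions, each accepting any layout.
const char *tiled = "N[tile:8x8]HWC"; // Rules 2-4: '[', namespace "tile" plus ':',
                                      // black-box payload "8x8", then closing ']'.
```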

````diff
@@ -70,9 +70,10 @@ Which includes the following virtual methods they can override:
   virtual bool isSatisfiedBy(TypeRef ty,
                              const TensorLayoutDescription &destLayout,
                              const TensorLayoutDescription *srcLayout) const
+```
 - This function checks if `ty` satisfies `destLayout` layout requirements, if `srcLayout` is provided for `ty`, take that into account.
 
-- `virtual std::array<TensorLayoutDescription, max_tensor_dimensions + 1> &getLayoutsForDims() const`
+- `virtual llvm::ArrayRef<TensorLayoutDescription> getLayoutsForDims() const`
 
 - This helper function returns an array of predefined layouts for all dimensions from `0-D` to Glow's max tensor layout dimension.
 
````
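A minimal sketch of what the new `getLayoutsForDims()` signature looks like in a backend subclass, assuming `TensorLayoutCommon` (seen in `lib/Graph/TensorLayout.cpp` below) is the overridable base class and using a hypothetical `layoutsForDims_` member; none of this code is from the commit:

```cpp
#include "glow/Graph/TensorLayout.h" // Assumed header for TensorLayoutCommon.
#include "llvm/ADT/ArrayRef.h"
#include <vector>

using namespace glow;

// Sketch only: returning llvm::ArrayRef lets a backend keep its predefined
// layouts in any contiguous container, instead of the fixed-size std::array
// the old signature required.
class MyBackendTensorLayout : public TensorLayoutCommon {
public:
  llvm::ArrayRef<TensorLayoutDescription> getLayoutsForDims() const override {
    return layoutsForDims_; // std::vector converts implicitly to ArrayRef.
  }

private:
  // Hypothetical storage: one entry per rank, 0-D up to max_tensor_dimensions.
  std::vector<TensorLayoutDescription> layoutsForDims_;
};
```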

examples/fr2en.cpp

Lines changed: 5 additions & 4 deletions
```diff
@@ -277,14 +277,15 @@ void Model::loadEncoder() {
                             {0, step, 0}, {batchSize_, step + 1, EMBEDDING_SIZE});
     Node *reshape =
         F_->createReshape("encoder." + std::to_string(step) + ".reshape",
-                          inputSlice, {batchSize_, EMBEDDING_SIZE}, "*");
+                          inputSlice, {batchSize_, EMBEDDING_SIZE}, ANY_LAYOUT);
     hidden = createPyTorchGRUCell(F_, reshape, hidden, wIh, bIh, wHh, bHh);
     outputs.push_back(hidden);
   }
 
   Node *output = F_->createConcat("encoder.output", outputs, 1);
-  Node *r2 = F_->createReshape("encoder.output.r2", output,
-                               {MAX_LENGTH * batchSize_, EMBEDDING_SIZE}, "*");
+  Node *r2 =
+      F_->createReshape("encoder.output.r2", output,
+                        {MAX_LENGTH * batchSize_, EMBEDDING_SIZE}, ANY_LAYOUT);
 
   encoderHiddenOutput_ = F_->createGather("encoder.outputNth", r2, seqLength_);
 }
@@ -346,7 +347,7 @@ void Model::loadDecoder() {
 
   Node *concat = F_->createConcat("decoder.output.concat", outputs, 0);
   Node *reshape = F_->createReshape("decoder.output.reshape", concat,
-                                    {MAX_LENGTH, batchSize_}, "*");
+                                    {MAX_LENGTH, batchSize_}, ANY_LAYOUT);
   auto *save = F_->createSave("decoder.output", reshape);
   output_ = save->getPlaceholder();
   bindings.allocate(output_);
```
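These hunks swap the bare `"*"` literal for `ANY_LAYOUT` in `createReshape` calls. Given the layout grammar above, where `*` means "any layout", the constant presumably just names that wildcard; a sketch of the presumed definition (the real one lives in Glow's headers and may be spelled differently):

```cpp
// Presumed shape of the constant: a named wildcard reads better at call
// sites than a bare "*" literal.
constexpr const char *ANY_LAYOUT = "*";
```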

include/glow/Graph/Graph.h

Lines changed: 2 additions & 2 deletions
```diff
@@ -1388,6 +1388,8 @@ Node *recursiveClone(Function *newF, Node *node, NodeMap &currToNew);
   { 0u, 3u, 1u, 2u }
 #define HWCN2NHWC \
   { 3u, 0u, 1u, 2u }
+#define NHWC2HWNC \
+  { 1u, 2u, 0u, 3u }
 
 llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Module &mod);
 
@@ -1397,8 +1399,6 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Function &F);
 
 llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Function *F);
 
-#define NHWC2HWNC \
-  { 1u, 2u, 0u, 3u }
 } // namespace glow
 
 #endif // GLOW_GRAPH_GRAPH_H
```
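Besides defining `NHWC2HWNC`, the second hunk moves it from after the stream operators up to the block of sibling shuffle macros. A sketch of how such a shuffle is typically consumed, assuming the usual `createTranspose(name, input, shuffle)` builder (not part of this commit):

```cpp
// Sketch: permute an NHWC tensor to HWNC; `F` is assumed to be a
// glow::Function* and `input` an NHWC-laid-out NodeValue.
Node *toHWNC = F->createTranspose("toHWNC", input, NHWC2HWNC);
```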

lib/Backends/OpenCL/OpenCLTensorLayout.cpp

Lines changed: 4 additions & 4 deletions
```diff
@@ -63,8 +63,8 @@ static const TensorLayoutDescription *getLayoutFromEnum(const N &node) {
 /// if it has one. Else returns nullptr. This will be removed and refactored
 /// if/when we move to using strings for all layout specifications and get rid
 /// of the enum.
-static const TensorLayoutDescription *getLayouForTempEnumRep(size_t n,
-                                                             const Node *node) {
+static const TensorLayoutDescription *
+getLayoutForTempEnumRep(size_t n, const Node *node) {
   if (const auto MP = llvm::dyn_cast<MaxPoolNode>(node)) {
     return getLayoutFromEnum(MP);
   }
@@ -96,7 +96,7 @@ std::string OpenCLTensorLayout::getNthInputLayoutRequirements(const Node *node,
   DCHECK_LE(dims.size(), max_tensor_dimensions) << "Too many dimensions";
   // TODO: Remove ->getLayout() enum and take a string like transpose. Refactor
   // the following after doing so.
-  const auto *layout = getLayouForTempEnumRep(n, node);
+  const auto *layout = getLayoutForTempEnumRep(n, node);
   if (layout) {
     return layout->getSerializedLayout();
   }
@@ -112,7 +112,7 @@ std::string OpenCLTensorLayout::getNthResultLayoutRequirements(const Node *node,
   DCHECK_LE(dims.size(), max_tensor_dimensions) << "Too many dimensions";
   // TODO: Remove ->getLayout() enum and take a string like transpose. Refactor
   // the following after doing so.
-  const auto *layout = getLayouForTempEnumRep(n, node);
+  const auto *layout = getLayoutForTempEnumRep(n, node);
   if (layout) {
     return layout->getSerializedLayout();
   }
```

lib/Graph/TensorLayout.cpp

Lines changed: 2 additions & 1 deletion
```diff
@@ -431,7 +431,8 @@ std::string TensorLayoutCommon::getNthResultLayoutRequirements(const Node *node,
   }
   // Dynamically form the layout description for transposes.
   auto input = TN->getInput();
-  auto inputLayout = getNthInputLayoutRequirements(node, 0);
+  auto inputLayout =
+      getNthInputLayoutRequirements(node, TransposeNode::InputIdx);
   auto inputLayoutHelper = TensorLayoutDescription(inputLayout);
   llvm::SmallVector<std::string, max_tensor_dimensions> dims(
       input.dims().size());
```

lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -1610,8 +1610,8 @@ static NodeValue tryToOptimizeConcatOfRehapes(Function *F, ConcatNode *CN) {
   return F->createReshape(
       CN->getInputs().front().getNode()->getName(), newCN,
       CN->getResult().dims(),
-      CanonicalTensorLayout::getInstance().getNthResultLayoutRequirements(CN,
-                                                                          0));
+      CanonicalTensorLayout::getInstance().getNthResultLayoutRequirements(
+          CN, ConcatNode::ResultIdx));
 }
 
 /// Simplify concat node.
```

lib/Optimizer/GraphOptimizer/Lower.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -174,8 +174,8 @@ static void lowerFullyConnectedGradNode(Function *F, CompilationContext &cctx,
   auto *dx2 = F->createMatMul("fcg.dot", dout, wT);
   auto *dx = F->createReshape(
       "fcg.inG", dx2, FCG.getInput().getType()->dims(),
-      CanonicalTensorLayout::getInstance().getNthInputLayoutRequirements(&FCG,
-                                                                         0));
+      CanonicalTensorLayout::getInstance().getNthInputLayoutRequirements(
+          &FCG, FullyConnectedGradNode::InputIdx));
   replaceAllUsesOfWith(cctx.loweredInfoMap, FCG.getGradOfInputNamedInput(), dx);
 
   // dw = xT * dout.
```
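The last three hunks (here, in `GraphOptimizer.cpp`, and in `TensorLayout.cpp`) share one pattern: a bare `0` index is replaced by the named constant the node class exposes, such as `TransposeNode::InputIdx`, `ConcatNode::ResultIdx`, and `FullyConnectedGradNode::InputIdx`. A sketch of the difference at a call site, where `CN` is a `ConcatNode *` and `canonical` stands for `CanonicalTensorLayout::getInstance()` as in the hunks above:

```cpp
// Both calls resolve to index 0 today, but only the second documents which
// result is meant and stays correct if the node's indexing ever changes.
auto before = canonical.getNthResultLayoutRequirements(CN, 0);
auto after = canonical.getNthResultLayoutRequirements(CN, ConcatNode::ResultIdx);
```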
