Commit e6b4f15
Merge pull request #2 from compnerd/explicit-ctor
IR: make ShapeNHWC constructor explicit (NFC)
2 parents: 735e7a5 + a8e3e54

5 files changed: +52 −39 lines

include/glow/IR/Type.h

Lines changed: 7 additions & 1 deletion

@@ -20,14 +20,20 @@ struct ShapeNHWC {
   size_t h; // Height
   size_t w; // Width
   size_t c; // # of Channels
-  ShapeNHWC(ArrayRef<size_t> shape) {
+
+  // TODO: deprecate this for std::array<size_t, 4>
+  explicit ShapeNHWC(ArrayRef<size_t> shape) {
     assert(shape.size() == 4 && "Invalid shape");
     n = shape[0];
     h = shape[1];
     w = shape[2];
     c = shape[3];
   }
 
+  explicit ShapeNHWC(size_t samples, size_t height, size_t width,
+                     size_t channels)
+      : n(samples), h(height), w(width), c(channels) {}
+
   bool equals(const ShapeNHWC &other) const {
     return n == other.n && h == other.h && w == other.w && c == other.c;
   }
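What the header change buys: with the ArrayRef<size_t> constructor marked explicit, a bare dimension list no longer converts implicitly to a ShapeNHWC, so copy-initialization such as `ShapeNHWC idim = input->dims();` stops compiling and every call site below must name the type. A minimal, self-contained sketch of the distinction; std::vector<size_t> stands in for ArrayRef<size_t> purely for illustration:

#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for llvm::ArrayRef<size_t>; for illustration only.
using DimList = std::vector<size_t>;

struct ShapeNHWC {
  size_t n, h, w, c;

  // explicit: a DimList no longer silently converts to a ShapeNHWC.
  explicit ShapeNHWC(const DimList &shape) {
    assert(shape.size() == 4 && "Invalid shape");
    n = shape[0];
    h = shape[1];
    w = shape[2];
    c = shape[3];
  }

  explicit ShapeNHWC(size_t samples, size_t height, size_t width,
                     size_t channels)
      : n(samples), h(height), w(width), c(channels) {}
};

int main() {
  DimList dims = {1, 28, 28, 3};

  ShapeNHWC a(dims);              // direct-initialization: OK
  ShapeNHWC b = ShapeNHWC(dims);  // explicit functional cast: OK
  // ShapeNHWC c = dims;          // error: constructor is explicit

  ShapeNHWC d(1, 28, 28, 3);      // the new four-argument constructor
  (void)a; (void)b; (void)d;
  return 0;
}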

src/glow/IR/IRBuilder.cpp

Lines changed: 2 additions & 2 deletions

@@ -21,7 +21,7 @@ void IRBuilder::deallocateActiveInstrs() {
 ConvolutionInst *IRBuilder::createConvOp(Value *input, size_t depth,
                                          size_t kernel, size_t stride,
                                          size_t pad) {
-  ShapeNHWC idim = input->dims();
+  ShapeNHWC idim = ShapeNHWC(input->dims());
   assert(idim.w >= kernel && idim.h >= kernel &&
          "buffer too small for selected stride");
 
@@ -47,7 +47,7 @@ ConvolutionInst *IRBuilder::createConvOp(Value *input, size_t depth,
 
 PoolInst *IRBuilder::createPoolOp(Value *input, PoolInst::OpKind kind,
                                   size_t kernel, size_t stride, size_t pad) {
-  ShapeNHWC idim = input->dims();
+  ShapeNHWC idim = ShapeNHWC(input->dims());
   assert(idim.w >= kernel && idim.h >= kernel &&
          "buffer too small for selected stride");
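A note on spelling: the functional form ShapeNHWC(input->dims()) used here and the direct-initialization ShapeNHWC idim(...) used in the files below invoke the same explicit constructor; only copy-initialization from the bare dimension list is rejected. Schematically (hypothetical snippet, same types as above):

ShapeNHWC a = ShapeNHWC(input->dims()); // OK: explicit conversion
ShapeNHWC b(input->dims());             // OK: direct-initialization
// ShapeNHWC c = input->dims();         // error: needs implicit conversion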

src/glow/IR/Instrs.cpp

Lines changed: 6 additions & 6 deletions

@@ -116,15 +116,15 @@ void ConvolutionInst::verify() const {
   (void)filter;
   (void)bias;
 
-  ShapeNHWC idim = src->getType()->dims();
-  ShapeNHWC odim = dest->getType()->dims();
+  ShapeNHWC idim(src->getType()->dims());
+  ShapeNHWC odim(dest->getType()->dims());
   (void)odim;
   assert(idim.w >= kernel_ && idim.h >= kernel_ &&
          "buffer too small for selected stride");
 
   auto outSz =
       ConvNode::calculateOutputDims(idim.h, idim.w, pad_, kernel_, stride_);
-  ShapeNHWC exp = ArrayRef<size_t>{idim.n, outSz.first, outSz.second, depth_};
+  ShapeNHWC exp(idim.n, outSz.first, outSz.second, depth_);
   (void)exp;
   assert(exp == odim && "Invalid output dimensions");
 
@@ -140,15 +140,15 @@ void PoolInst::verify() const {
   Value *src = getOperand(1).first;
   Value *srcXY = getOperand(2).first;
   (void)srcXY;
-  ShapeNHWC idim = src->getType()->dims();
-  ShapeNHWC odim = dest->getType()->dims();
+  ShapeNHWC idim = ShapeNHWC(src->getType()->dims());
+  ShapeNHWC odim = ShapeNHWC(dest->getType()->dims());
   (void)odim;
   assert(idim.w >= kernel_ && idim.h >= kernel_ &&
          "buffer too small for selected stride");
 
   auto outSz =
       ConvNode::calculateOutputDims(idim.h, idim.w, pad_, kernel_, stride_);
-  ShapeNHWC exp = ArrayRef<size_t>{idim.n, outSz.first, outSz.second, idim.c};
+  ShapeNHWC exp(idim.n, outSz.first, outSz.second, idim.c);
   (void)exp;
   assert(exp == odim && "Invalid output dimensions");
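An incidental improvement in the verifiers: the new four-argument constructor replaces construction through a temporary ArrayRef<size_t>, so passing the wrong number of dimensions now fails at compile time rather than tripping the size() == 4 assert at runtime. Schematically:

// Before: temporary array, arity checked only by a runtime assert.
ShapeNHWC exp = ArrayRef<size_t>{idim.n, outSz.first, outSz.second, depth_};
// After: four named parameters, arity checked by the compiler.
ShapeNHWC exp(idim.n, outSz.first, outSz.second, depth_);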

src/glow/Interpreter/InterpreterNodes.cpp

Lines changed: 15 additions & 14 deletions

@@ -37,8 +37,8 @@ void Interpreter::fwdConvolutionInst(Context *ctx, bool isTrain,
   size_t pad = I->getPad();
   size_t stride = I->getStride();
 
-  ShapeNHWC odim = outW.dims();
-  ShapeNHWC idim = inW.dims();
+  ShapeNHWC odim(outW.dims());
+  ShapeNHWC idim(inW.dims());
 
   // For each input in the batch:
   for (size_t n = 0; n < idim.n; n++) {
@@ -94,8 +94,8 @@ void Interpreter::bwdConvolutionInst(Context *ctx, const ConvolutionInst *I) {
   size_t pad = I->getPad();
   size_t stride = I->getStride();
 
-  ShapeNHWC odim = outW.dims();
-  ShapeNHWC idim = inW.dims();
+  ShapeNHWC odim(outW.dims());
+  ShapeNHWC idim(inW.dims());
 
   // For each input in the batch:
   for (size_t n = 0; n < odim.n; n++) {
@@ -154,8 +154,8 @@ void Interpreter::fwdPoolMax_impl(Context *ctx, const PoolInst *I) {
   auto inW = getWeightHandle(ctx, I->getSrc());
   auto outW = getWeightHandle(ctx, I->getDest());
 
-  ShapeNHWC odim = outW.dims();
-  ShapeNHWC idim = inW.dims();
+  ShapeNHWC odim(outW.dims());
+  ShapeNHWC idim(inW.dims());
 
   auto pad = I->getPad();
   auto filterSize = I->getKernel();
@@ -215,8 +215,8 @@ void Interpreter::fwdPoolAvg_impl(Context *ctx, const PoolInst *I) {
   auto inW = getWeightHandle(ctx, I->getSrc());
   auto outW = getWeightHandle(ctx, I->getDest());
 
-  ShapeNHWC odim = outW.dims();
-  ShapeNHWC idim = inW.dims();
+  ShapeNHWC odim(outW.dims());
+  ShapeNHWC idim(inW.dims());
 
   auto pad = I->getPad();
   auto filterSize = I->getKernel();
@@ -273,7 +273,7 @@ void Interpreter::bwdPoolMax_impl(Context *ctx, const PoolInst *I) {
   auto outW = getWeightHandle(ctx, I->getDest());
   auto outG = getGradHandle(ctx, I->getDest());
 
-  ShapeNHWC odim = outW.dims();
+  ShapeNHWC odim(outW.dims());
 
   auto SXY = getTensorForValue(I->srcXY())->getHandle<size_t>();
 
@@ -305,8 +305,8 @@ void Interpreter::bwdPoolAvg_impl(Context *ctx, const PoolInst *I) {
   auto outW = getWeightHandle(ctx, I->getDest());
   auto outG = getGradHandle(ctx, I->getDest());
 
-  ShapeNHWC odim = outW.dims();
-  ShapeNHWC idim = inW.dims();
+  ShapeNHWC odim(outW.dims());
+  ShapeNHWC idim(inW.dims());
 
   auto pad = I->getPad();
   auto filterSize = I->getKernel();
@@ -852,8 +852,9 @@ void Interpreter::fwdLocalResponseNormalizationInst(
   auto outW = getWeightHandle(ctx, I->getDest());
   auto scaleCache = getWeightHandle(ctx, I->getScale());
 
-  ShapeNHWC odim = outW.dims();
-  ShapeNHWC idim = inW.dims();
+  ShapeNHWC odim(outW.dims());
+  ShapeNHWC idim(inW.dims());
+
   (void)odim;
 
   // LRN node does not change the shape of the input.
@@ -918,7 +919,7 @@ void Interpreter::bwdLocalResponseNormalizationInst(
   auto outG = getGradHandle(ctx, I->getDest());
   auto scaleCache = getWeightHandle(ctx, I->getScale());
 
-  ShapeNHWC odim = outW.dims();
+  ShapeNHWC odim(outW.dims());
 
   auto halfWindowSize = I->gethalfWindowSize();
   auto beta = I->getBeta();

src/glow/Network/Nodes.cpp

Lines changed: 22 additions & 16 deletions

@@ -12,7 +12,7 @@ ConvNode::ConvNode(Network *N, NodeBase *input, size_t outDepth,
 
 void ConvNode::init(Context *ctx) const {
   assert(input_ && input_->size(ctx) && "Invalid input");
-  ShapeNHWC idim = input_->dims(ctx);
+  ShapeNHWC idim(input_->dims(ctx));
   assert(idim.h >= filterSize_ && idim.w >= filterSize_ &&
          "buffer too small for selected stride");
 
@@ -65,8 +65,8 @@ std::string ConvNode::getDebugRepr(Context *ctx) const {
 }
 
 void ConvNode::forward(Context *ctx, PassKind kind) const {
-  ShapeNHWC odim = dims(ctx);
-  ShapeNHWC idim = input_->dims(ctx);
+  ShapeNHWC odim(dims(ctx));
+  ShapeNHWC idim(input_->dims(ctx));
 
   auto inW = input_->getWeightHandle(ctx);
   auto outW = getWeightHandle(ctx);
@@ -114,8 +114,9 @@ void ConvNode::forward(Context *ctx, PassKind kind) const {
 }
 
 void ConvNode::backward(Context *ctx) const {
-  ShapeNHWC odim = dims(ctx);
-  ShapeNHWC idim = input_->dims(ctx);
+  ShapeNHWC odim(dims(ctx));
+  ShapeNHWC idim(input_->dims(ctx));
+
   auto inW = input_->getWeightHandle(ctx);
   auto inG = input_->getGradHandle(ctx);
   auto outG = getGradHandle(ctx);
@@ -176,7 +177,7 @@ MaxPoolNode::MaxPoolNode(Network *N, NodeBase *input, OpKind kind,
 
 void MaxPoolNode::init(Context *ctx) const {
   assert(input_ && input_->size(ctx) && "Invalid input");
-  ShapeNHWC idim = input_->dims(ctx);
+  ShapeNHWC idim(input_->dims(ctx));
   assert(idim.w >= filterSize_ && idim.h >= filterSize_ &&
          "buffer too small for selected stride");
 
@@ -229,8 +230,9 @@ void MaxPoolNode::backward(Context *ctx) const {
 }
 
 void MaxPoolNode::forwardMax(Context *ctx) const {
-  ShapeNHWC odim = dims(ctx);
-  ShapeNHWC idim = input_->dims(ctx);
+  ShapeNHWC odim(dims(ctx));
+  ShapeNHWC idim(input_->dims(ctx));
+
   auto inW = input_->getWeightHandle(ctx);
   auto outW = getWeightHandle(ctx);
 
@@ -286,7 +288,8 @@ void MaxPoolNode::forwardMax(Context *ctx) const {
 }
 
 void MaxPoolNode::backwardMax(Context *ctx) const {
-  ShapeNHWC odim = dims(ctx);
+  ShapeNHWC odim(dims(ctx));
+
   auto inG = input_->getGradHandle(ctx);
   auto outG = getGradHandle(ctx);
 
@@ -319,8 +322,9 @@ void MaxPoolNode::forwardAvg(Context *ctx) const {
   // Implement the avg pooling operation as defined here:
   // https://arxiv.org/abs/1312.4400
 
-  ShapeNHWC odim = dims(ctx);
-  ShapeNHWC idim = input_->dims(ctx);
+  ShapeNHWC odim(dims(ctx));
+  ShapeNHWC idim(input_->dims(ctx));
+
   auto inW = input_->getWeightHandle(ctx);
   auto outW = getWeightHandle(ctx);
 
@@ -360,8 +364,9 @@ void MaxPoolNode::forwardAvg(Context *ctx) const {
 }
 
 void MaxPoolNode::backwardAvg(Context *ctx) const {
-  ShapeNHWC odim = dims(ctx);
-  ShapeNHWC idim = input_->dims(ctx);
+  ShapeNHWC odim(dims(ctx));
+  ShapeNHWC idim(input_->dims(ctx));
+
   auto inG = input_->getGradHandle(ctx);
   auto outG = getGradHandle(ctx);
   FloatTy filterArea = filterSize_ * filterSize_;
@@ -529,8 +534,9 @@ void LRNNode::forward(Context *ctx, PassKind kind) const {
   auto inW = input_->getWeightHandle(ctx);
   auto scaleCache = ctx->getTensor(&scale_)->getHandle<FloatTy>();
 
-  ShapeNHWC odim = dims(ctx);
-  ShapeNHWC idim = input_->dims(ctx);
+  ShapeNHWC odim(dims(ctx));
+  ShapeNHWC idim(input_->dims(ctx));
+
   (void)odim;
 
   // LRN node does not change the shape of the input.
@@ -591,7 +597,7 @@ void LRNNode::backward(Context *ctx) const {
   auto inW = input_->getWeightHandle(ctx);
   auto scaleCache = ctx->getTensor(&scale_)->getHandle<FloatTy>();
 
-  ShapeNHWC odim = dims(ctx);
+  ShapeNHWC odim(dims(ctx));
 
   auto windowSize = 2 * halfWindowSize_ + 1;
   auto normedAlpha = alpha_ / windowSize;
