Commit deaaa59

IR: Move the last layer into the new IR implementation.
Summary:
IR: Move the last layer into the new IR implementation.

This commit moves the LRN node, which was implemented in the old network, into the new IR. The code and grad checks are just copied as-is into the new IR with minor interface changes. The tests seem to pass.

Test Plan: Updated the tests to use the IR interface.

Reviewers: #glow, meghanl, abdulras

Reviewed By: abdulras

Subscribers: #glow

Differential Revision: https://phabricator.intern.facebook.com/D5940118

Tags: none

Signature: 5940118:1506703345:242e78cb21da6bb6a7ee5ca96c8eec7439978608
1 parent 919e904 commit deaaa59

File tree: 7 files changed, +250 -0 lines changed

  include/glow/IR/IRBuilder.h
  include/glow/IR/Instrs.def
  include/glow/IR/Instrs.h
  src/glow/IR/IRBuilder.cpp
  src/glow/IR/Instrs.cpp
  src/glow/Interpreter/InterpreterNodes.cpp
  tests/unittests/IRGradCheck.cpp


include/glow/IR/IRBuilder.h

Lines changed: 10 additions & 0 deletions
@@ -55,6 +55,11 @@ class IRBuilder {
                                                 float epsilon = 1e-5,
                                                 float momentum = 0.9);
 
+  LocalResponseNormalizationInst *
+  createLocalResponseNormalizationOp(Value *input, size_t halfWindowSize = 2,
+                                     float alpha = 1e-4, float beta = 0.75,
+                                     float k = 2.0);
+
   ArithmeticInst *createArithmeticOp(Value *LHS, Value *RHS,
                                      ArithmeticInst::OpKind op);

@@ -101,6 +106,11 @@ class IRBuilder {
       Value *dest, Value *src, Value *scale, Value *bias, Value *mean,
       Value *var, size_t channelIdx, float epsilon, float momentum);
 
+  LocalResponseNormalizationInst *
+  createLocalResponseNormalizationInst(Value *dest, Value *src, Value *scale,
+                                       size_t halfWindowSize, float alpha,
+                                       float beta, float k);
+
   ArithmeticInst *createArithmeticInst(Value *dest, Value *LHS, Value *RHS,
                                        ArithmeticInst::OpKind kind);
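
The builder gets the same two-level interface that the other layers in the new IR already use: createLocalResponseNormalizationOp allocates the destination and the "scale" cache activation itself, while createLocalResponseNormalizationInst takes pre-allocated operands. A minimal usage sketch, assuming an existing IRBuilder bb, an input activation in, and that createAllocActivationInst is reachable from the call site (none of these appear in this diff):

  // High-level form: the builder allocates the output and the scale tensor,
  // both with the same shape as the input.
  auto *lrn = bb.createLocalResponseNormalizationOp(in, /* halfWindowSize */ 2,
                                                    /* alpha */ 1e-4,
                                                    /* beta */ 0.75,
                                                    /* k */ 2.0);

  // Low-level form: the caller provides the operand buffers explicitly.
  auto *dest = bb.createAllocActivationInst(in->getType());
  auto *scale = bb.createAllocActivationInst(in->getType(), "scale");
  auto *lrn2 = bb.createLocalResponseNormalizationInst(
      dest, in, scale, /* halfWindowSize */ 2, /* alpha */ 1e-4,
      /* beta */ 0.75, /* k */ 2.0);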

include/glow/IR/Instrs.def

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ DEF_INSTR(TransposeInst, transpose)
 DEF_INSTR(ReshapeInst, reshape)
 DEF_INSTR(ConcatInst, concat)
 DEF_INSTR(BatchNormalizationInst, batchnormalization)
+DEF_INSTR(LocalResponseNormalizationInst, localresponsenormalization)
 DEF_INSTR(ArithmeticInst, arithmetic)
 DEF_VALUE(WeightVar, weight)
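
Instrs.def is an X-macro file: each place that needs the full list of instruction kinds defines DEF_INSTR/DEF_VALUE before including it, so the single line added above is what registers the new kind at every expansion site. A rough illustration of the pattern (hypothetical consumer; the real expansion sites elsewhere in Glow are not shown in this diff):

  // Hypothetical example: expand Instrs.def into a kind-to-name mapping.
  const char *getKindName(Kinded::Kind kind) {
    switch (kind) {
  #define DEF_INSTR(CLASS, NAME)                                               \
    case Kinded::Kind::CLASS##Kind:                                            \
      return #NAME;
  #define DEF_VALUE(CLASS, NAME)                                               \
    case Kinded::Kind::CLASS##Kind:                                            \
      return #NAME;
  #include "glow/IR/Instrs.def"
    }
    return "<unknown>";
  }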

include/glow/IR/Instrs.h

Lines changed: 39 additions & 0 deletions
@@ -355,6 +355,45 @@ class ArithmeticInst : public Instruction {
   void verify() const;
 };
 
+class LocalResponseNormalizationInst : public Instruction {
+  /// The number of neighbouring channels on each side to sum over
+  size_t halfWindowSize_;
+
+  /// The scaling parameter
+  float alpha_;
+
+  /// The exponent parameter
+  float beta_;
+
+  /// The offset parameter
+  float k_;
+
+public:
+  LocalResponseNormalizationInst(Value *dest, Value *src, Value *scale,
+                                 size_t halfWindowSize, float alpha, float beta,
+                                 float k)
+      : Instruction(Kinded::Kind::LocalResponseNormalizationInstKind,
+                    dest->getType(),
+                    {{dest, OperandKind::kOut},
+                     {src, OperandKind::kIn},
+                     {scale, OperandKind::kInOut}}),
+        halfWindowSize_(halfWindowSize), alpha_(alpha), beta_(beta), k_(k) {}
+
+  static bool classof(const Kinded *k) {
+    return k->getKind() == Kinded::Kind::LocalResponseNormalizationInstKind;
+  }
+  std::string getExtraDesc() const;
+  Value *getDest() const { return getOperand(0).first; }
+  Value *getSrc() const { return getOperand(1).first; }
+  Value *getScale() const { return getOperand(2).first; }
+
+  size_t gethalfWindowSize() const { return halfWindowSize_; }
+  float getAlpha() const { return alpha_; }
+  float getBeta() const { return beta_; }
+  float getK() const { return k_; }
+  void verify() const;
+};
+
 class WeightVar : public Value {
 public:
   enum class InitKind {

src/glow/IR/IRBuilder.cpp

Lines changed: 20 additions & 0 deletions
@@ -174,6 +174,17 @@ BatchNormalizationInst *IRBuilder::createBatchNormalizationOp(Value *input,
                                        channelIdx, epsilon, momentum);
 }
 
+LocalResponseNormalizationInst *IRBuilder::createLocalResponseNormalizationOp(
+    Value *input, size_t halfWindowSize, float alpha, float beta, float k) {
+  auto Ty = input->getType();
+  auto *scale = createAllocActivationInst(Ty, "scale");
+
+  // The output tensor is of the same shape as the input tensor.
+  auto *res = createAllocActivationInst(Ty);
+  return createLocalResponseNormalizationInst(input, res, scale, halfWindowSize,
+                                              alpha, beta, k);
+}
+
 ArithmeticInst *IRBuilder::createArithmeticOp(Value *LHS, Value *RHS,
                                               ArithmeticInst::OpKind op) {
   assert(LHS->dims() == RHS->dims() && "Invalid operand shapes");

@@ -281,6 +292,15 @@ BatchNormalizationInst *IRBuilder::createBatchNormalizationInst(
   return A;
 }
 
+LocalResponseNormalizationInst *IRBuilder::createLocalResponseNormalizationInst(
+    Value *dest, Value *src, Value *scale, size_t halfWindowSize, float alpha,
+    float beta, float k) {
+  auto *A = new LocalResponseNormalizationInst(dest, src, scale, halfWindowSize,
+                                               alpha, beta, k);
+  M_.pushInstr(A);
+  return A;
+}
+
 ArithmeticInst *IRBuilder::createArithmeticInst(Value *dest, Value *LHS,
                                                 Value *RHS,
                                                 ArithmeticInst::OpKind kind) {

src/glow/IR/Instrs.cpp

Lines changed: 9 additions & 0 deletions
@@ -69,6 +69,10 @@ std::string BatchNormalizationInst::getExtraDesc() const {
   return listToString(channelIdx_, epsilon_, momentum_);
 }
 
+std::string LocalResponseNormalizationInst::getExtraDesc() const {
+  return listToString(halfWindowSize_, alpha_, beta_, k_);
+}
+
 const char *ArithmeticInst::getKindStr() const {
   const char *names[] = {"add", "mul", nullptr};
   return names[static_cast<int>(kind_)];

@@ -189,6 +193,7 @@ void SoftMaxInst::verify() const {
 }
 void RegressionInst::verify() const {
   checkSameType(getOperand(0), getOperand(1));
+  checkSameType(getOperand(0), getOperand(2));
 }
 
 void ReshapeInst::verify() const {

@@ -240,6 +245,10 @@ void BatchNormalizationInst::verify() const {
   assert(getOperand(4).first->getType()->dims() == exp && "Invalid mean dim");
   assert(getOperand(5).first->getType()->dims() == exp && "Invalid var dim");
 }
+void LocalResponseNormalizationInst::verify() const {
+  checkSameType(getOperand(0), getOperand(1));
+  checkSameType(getOperand(0), getOperand(2));
+}
 void ArithmeticInst::verify() const {
   checkSameType(getOperand(0), getOperand(1));
   checkSameType(getOperand(0), getOperand(2));

src/glow/Interpreter/InterpreterNodes.cpp

Lines changed: 135 additions & 0 deletions
@@ -845,6 +845,141 @@ void Interpreter::bwdBatchNormalizationInst(Context *ctx,
   }
 }
 
+void Interpreter::fwdLocalResponseNormalizationInst(
+    glow::Context *ctx, bool isTrain,
+    const glow::LocalResponseNormalizationInst *I) {
+  auto inW = getWeightHandle(ctx, I->getSrc());
+  auto outW = getWeightHandle(ctx, I->getDest());
+  auto scaleCache = getWeightHandle(ctx, I->getScale());
+
+  ShapeNHWC odim = outW.dims();
+  ShapeNHWC idim = inW.dims();
+  (void)odim;
+
+  // LRN node does not change the shape of the input.
+  assert(odim == idim && "Output of LRN node must be same shape as input");
+
+  // LRN node normalizes across channels, so the input must have a minimum
+  // depth of 1.
+  assert(idim.c > 0 && "Input of LRN node must have a minimum depth of 1");
+
+  auto halfWindowSize = I->gethalfWindowSize();
+  auto k = I->getK();
+  auto beta = I->getBeta();
+  auto windowSize = 2 * halfWindowSize + 1;
+  auto normedAlpha = I->getAlpha() / windowSize;
+
+  // For every input in the batch:
+  for (size_t n = 0; n < idim.n; n++) {
+
+    // For every row:
+    for (size_t h = 0; h < idim.h; h++) {
+
+      // For every column:
+      for (size_t w = 0; w < idim.w; w++) {
+
+        FloatTy squareSum = 0.0;
+
+        // Compute squareSum for first channel.
+        for (size_t c = 1; c <= halfWindowSize && c < idim.c; c++) {
+          auto val = inW.at({n, h, w, c});
+          squareSum += (val * val);
+        }
+
+        // For every channel:
+        for (size_t c = 0; c < idim.c; c++) {
+          auto scale = k + normedAlpha * squareSum;
+
+          // This will be used to accelerate the backward pass.
+          scaleCache.at({n, h, w, c}) = scale;
+
+          auto normFactor = std::pow(scale, -beta);
+          outW.at({n, h, w, c}) = inW.at({n, h, w, c}) * normFactor;
+
+          // Modify squareSum for next channel.
+          auto subIndex = c - halfWindowSize;
+          auto addIndex = c + halfWindowSize + 1;
+          auto sub = (c >= halfWindowSize) ? inW.at({n, h, w, subIndex}) : 0;
+          auto add = (addIndex < idim.c) ? inW.at({n, h, w, addIndex}) : 0;
+
+          // Subtract out "rear" end of this window, add "front" end of next.
+          squareSum = squareSum - (sub * sub) + (add * add);
+        }
+      }
+    }
+  }
+}
+
+void Interpreter::bwdLocalResponseNormalizationInst(
+    glow::Context *ctx, const glow::LocalResponseNormalizationInst *I) {
+  auto inW = getWeightHandle(ctx, I->getSrc());
+  auto inG = getGradHandle(ctx, I->getSrc());
+  auto outW = getWeightHandle(ctx, I->getDest());
+  auto outG = getGradHandle(ctx, I->getDest());
+  auto scaleCache = getWeightHandle(ctx, I->getScale());
+
+  ShapeNHWC odim = outW.dims();
+
+  auto halfWindowSize = I->gethalfWindowSize();
+  auto beta = I->getBeta();
+  auto windowSize = 2 * halfWindowSize + 1;
+  auto normedAlpha = I->getAlpha() / windowSize;
+
+  // For every input in the batch:
+  for (size_t n = 0; n < odim.n; n++) {
+
+    // For every row:
+    for (size_t h = 0; h < odim.h; h++) {
+
+      // For every column:
+      for (size_t w = 0; w < odim.w; w++) {
+
+        FloatTy sum = 0.0;
+
+        // Compute sum for first channel.
+        for (size_t c = 1; c <= halfWindowSize && c < odim.c; c++) {
+          auto outw = outW.at({n, h, w, c});
+          auto scale = scaleCache.at({n, h, w, c});
+          auto outg = outG.at({n, h, w, c});
+          sum += (outg * (outw / scale));
+        }
+
+        // For every channel:
+        for (size_t c = 0; c < odim.c; c++) {
+          auto outg = outG.at({n, h, w, c});
+          auto scale = scaleCache.at({n, h, w, c});
+          auto inw = inW.at({n, h, w, c});
+
+          inG.at({n, h, w, c}) = outg * std::pow(scale, -beta) -
+                                 2 * normedAlpha * beta * inw * sum;
+
+          // Modify sum for next channel.
+          auto subIndex = c - halfWindowSize;
+          auto addIndex = c + halfWindowSize + 1;
+
+          if (c >= halfWindowSize) {
+            auto outw = outW.at({n, h, w, subIndex});
+            auto scale = scaleCache.at({n, h, w, subIndex});
+            auto outg = outG.at({n, h, w, subIndex});
+
+            // Subtract "rear" end of this window.
+            sum -= (outg * (outw / scale));
+          }
+
+          if (addIndex < odim.c) {
+            auto outw = outW.at({n, h, w, addIndex});
+            auto scale = scaleCache.at({n, h, w, addIndex});
+            auto outg = outG.at({n, h, w, addIndex});
+
+            // Add "front" end of next window.
+            sum += (outg * (outw / scale));
+          }
+        }
+      }
+    }
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Arithmetic operations
 //===----------------------------------------------------------------------===//

tests/unittests/IRGradCheck.cpp

Lines changed: 36 additions & 0 deletions
@@ -232,6 +232,42 @@ TEST(Network, gradientCheck_batchNorm) {
   performGradCheck(IP, RN, A, Ex, &inputs, &outputs, 0.001, 0.004);
 }
 
+TEST(Network, gradientCheck_LRN) {
+  Interpreter IP;
+  IP.getConfig().maxNumThreads = 1;
+
+  size_t numDim = 8;
+  size_t numOutputElem = numDim;
+
+  Value *A;
+  Value *Ex;
+  Instruction *RN;
+  {
+    IRBuilder bb(IP.getModule());
+
+    A = bb.createWeightVar(ElemKind::FloatTy, {1, numDim, numDim, 3});
+    Ex = bb.createWeightVar(ElemKind::FloatTy, {1, numOutputElem});
+
+    Instruction *O = bb.createLocalResponseNormalizationOp(A, 3, 0.0001, 0.9);
+    O = bb.createFullyConnectedOp(*O, numOutputElem);
+    RN = bb.createRegressionOp(*O, Ex);
+  }
+
+  IP.getModule().verify();
+  IP.initVars();
+
+  Tensor inputs(ElemKind::FloatTy, {1, numDim, numDim, 3});
+  Tensor outputs(ElemKind::FloatTy, {1, numOutputElem});
+
+  auto inputsH = inputs.getHandle<FloatTy>();
+  auto outputsH = outputs.getHandle<FloatTy>();
+
+  inputsH.randomize(1);
+  outputsH.randomize(1);
+
+  performGradCheck(IP, RN, A, Ex, &inputs, &outputs, 0.001, 0.004);
+}
+
 TEST(Network, gradientCheck_Arithmetic) {
   Interpreter IP;
   IP.getConfig().maxNumThreads = 1;
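
The new test mirrors gradientCheck_batchNorm: it places the LRN op in front of a fully connected layer and a regression loss, then hands the graph to performGradCheck. Presumably (the helper itself is not part of this diff) the check perturbs each input element and compares a finite-difference estimate of the loss gradient against the analytic gradient from bwdLocalResponseNormalizationInst, along the lines of:

  \frac{L(x_i + \varepsilon) - L(x_i - \varepsilon)}{2\varepsilon} \approx \frac{\partial L}{\partial x_i}

with the last two arguments (0.001 and 0.004) controlling the perturbation size and the allowed error, the same values used by the neighbouring tests.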
