diff --git a/include/glow/Graph/Graph.h b/include/glow/Graph/Graph.h
index 26bb912d33..a96cd8df7c 100644
--- a/include/glow/Graph/Graph.h
+++ b/include/glow/Graph/Graph.h
@@ -254,8 +254,8 @@ class Function final : public Named {
                                 unsigned_t pad);
 
   FullyConnectedNode *createFullyConnected(llvm::StringRef name,
-                                           NodeValue input, Variable *W,
-                                           Variable *B);
+                                           NodeValue input, Storage *W,
+                                           Storage *B);
 
   /// Create a fully connected node with the specified output type.
   /// Note, outputDepth is infered based on the output type.
@@ -611,6 +611,20 @@ class Function final : public Named {
 
   SaveNode *createSave(Context &ctx, llvm::StringRef name, NodeValue input);
 
+  void createSimpleRNN(Context &ctx, llvm::StringRef namePrefix,
+                       const llvm::ArrayRef<Node *> inputs, unsigned batchSize,
+                       unsigned hiddenSize, unsigned outputSize,
+                       std::vector<NodeValue> &outputs);
+
+  void createGRU(Context &ctx, llvm::StringRef namePrefix,
+                 const llvm::ArrayRef<Node *> inputs, unsigned batchSize,
+                 unsigned hiddenSize, unsigned outputSize,
+                 std::vector<NodeValue> &outputs);
+
+  void createLSTM(Context &ctx, llvm::StringRef namePrefix,
+                  const llvm::ArrayRef<Node *> inputs, unsigned batchSize,
+                  unsigned hiddenSize, unsigned outputSize,
+                  std::vector<NodeValue> &outputs);
   /// @}
 
   /// Erase the node \p N from the Function.
diff --git a/lib/Graph/Graph.cpp b/lib/Graph/Graph.cpp
index 9d4d6bcdb3..8f51da5a97 100644
--- a/lib/Graph/Graph.cpp
+++ b/lib/Graph/Graph.cpp
@@ -587,8 +587,8 @@ AvgPoolNode *Function::createAvgPool(llvm::StringRef name, NodeValue input,
 }
 
 FullyConnectedNode *Function::createFullyConnected(llvm::StringRef name,
-                                                   NodeValue input, Variable *W,
-                                                   Variable *B) {
+                                                   NodeValue input, Storage *W,
+                                                   Storage *B) {
   TypeRef T = input.getType();
   TypeRef OT = getParent()->uniqueTypeWithNewShape(
       T, {input.dims()[0], B->getType()->dims()[0]});
@@ -1998,6 +1998,379 @@ SaveNode *Function::createSave(Context &ctx, llvm::StringRef name,
   return addNode(new SaveNode(name, input, dest));
 }
+void Function::createGRU(Context &ctx, llvm::StringRef namePrefix,
+                         llvm::ArrayRef<Node *> inputs, unsigned batchSize,
+                         unsigned hiddenSize, unsigned outputSize,
+                         std::vector<NodeValue> &outputs) {
+  std::string nameBase = namePrefix;
+  const unsigned timeSteps = inputs.size();
+  assert(timeSteps > 0 && "empty input");
+  const unsigned inputSize = inputs.front()->dims(0).back();
+  assert(inputSize > 0 && "input dimensionality is zero");
+
+  // Initialize the state to zero.
+  Placeholder *HInit = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {batchSize, hiddenSize}, "initial_state", false);
+  ctx.allocate(HInit)->zero();
+  Node *Ht = HInit;
+
+  // Update gate:
+  // Z <- sigmoid(Wxz * x + Whz * h + bz)
+  // Reset gate:
+  // R <- sigmoid(Wxr * x + Whr * h + br)
+  // Hidden state:
+  // h <- Z . h + (1 - Z) tanh (Wxh * x + Whh * (R . h) + bh)
+
+  // update gate
+  float bUpdate = 0.1;
+  Placeholder *Wxz = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxz", true);
+  Placeholder *Whz = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whz", true);
+  Placeholder *Bz1 = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bz1", true);
+  Placeholder *Bz2 = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bz2", true);
+
+  ctx.allocate(Wxz)->init(glow::Tensor::InitKind::Xavier, inputSize, getPRNG());
+  ctx.allocate(Whz)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
+                          getPRNG());
+  ctx.allocate(Bz1)->init(glow::Tensor::InitKind::Broadcast, bUpdate,
+                          getPRNG());
+  ctx.allocate(Bz2)->init(glow::Tensor::InitKind::Broadcast, bUpdate,
+                          getPRNG());
+
+  // Reset gate.
+  float bReset = -1.0;
+  Placeholder *Wxr = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxr", true);
+  Placeholder *Whr = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whr", true);
+  Placeholder *Br1 = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize}, nameBase + ".br1", true);
+  Placeholder *Br2 = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize}, nameBase + ".br2", true);
+
+  ctx.allocate(Wxr)->init(glow::Tensor::InitKind::Xavier, inputSize, getPRNG());
+  ctx.allocate(Whr)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
+                          getPRNG());
+  ctx.allocate(Br1)->init(glow::Tensor::InitKind::Broadcast, bReset, getPRNG());
+  ctx.allocate(Br2)->init(glow::Tensor::InitKind::Broadcast, bReset, getPRNG());
+
+  // hidden state
+  float b = 0.1;
+  Placeholder *Wxh = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxh", true);
+  Placeholder *Whh = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whh", true);
+  Placeholder *Bh1 = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bh1", true);
+  Placeholder *Bh2 = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bh2", true);
+
+  ctx.allocate(Wxh)->init(glow::Tensor::InitKind::Xavier, inputSize, getPRNG());
+  ctx.allocate(Whh)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
+                          getPRNG());
+  ctx.allocate(Bh1)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
+  ctx.allocate(Bh2)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
+
+  // Output Layer.
+  Placeholder *Why = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize, outputSize}, nameBase + ".Why", true);
+  Placeholder *By = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {outputSize}, nameBase + ".by", true);
+
+  ctx.allocate(Why)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
+                          getPRNG());
+  ctx.allocate(By)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
+
+  auto ty = getParent()->uniqueType(ElemKind::FloatTy, {batchSize, hiddenSize});
+  auto *Ones = createSplat(nameBase + ".ones", ty, 1.0);
+
+  std::vector<Node *> outputNodes;
+  for (unsigned t = 0; t < timeSteps; t++) {
+    auto fc1Name = nameBase + ".fc1." + std::to_string(t);
+    auto fc2Name = nameBase + ".fc2." + std::to_string(t);
+    auto add1Name = nameBase + ".add1." + std::to_string(t);
+    auto sigmoid1Name = nameBase + ".sigmoid1." + std::to_string(t);
+
+    auto *Zt = createSigmoid(
+        sigmoid1Name,
+        createAdd(add1Name, createFullyConnected(fc1Name, Ht, Whz, Bz1),
+                  createFullyConnected(fc2Name, inputs[t], Wxz, Bz2)));
+
+    auto fc3Name = nameBase + ".fc3." + std::to_string(t);
+    auto fc4Name = nameBase + ".fc4." + std::to_string(t);
+    auto add2Name = nameBase + ".add2." + std::to_string(t);
+    auto sigmoid2Name = nameBase + ".sigmoid2." + std::to_string(t);
+
+    auto *Rt = createSigmoid(
+        sigmoid2Name,
+        createAdd(add2Name, createFullyConnected(fc3Name, Ht, Whr, Br1),
+                  createFullyConnected(fc4Name, inputs[t], Wxr, Br2)));
+
+    auto zhtName = nameBase + ".zh." + std::to_string(t);
+    auto *ZHt = createMul(zhtName, Zt, Ht);
+
+    auto oneMinusZtName = nameBase + ".1-z." + std::to_string(t);
+    auto *OneMinusZt = createSub(oneMinusZtName, Ones, Zt);
+
+    auto rhtName = nameBase + ".rh." + std::to_string(t);
+    auto *RHt = createMul(rhtName, Rt, Ht);
+
+    auto fc5Name = nameBase + ".fc5." + std::to_string(t);
+    auto fc6Name = nameBase + ".fc6." + std::to_string(t);
+    auto add3Name = nameBase + ".add3." + std::to_string(t);
+    auto tanh1Name = nameBase + ".tanh1." + std::to_string(t);
+
+    auto *Ut = createTanh(
+        tanh1Name,
+        createAdd(add3Name, createFullyConnected(fc5Name, RHt, Whh, Bh1),
+                  createFullyConnected(fc6Name, inputs[t], Wxh, Bh2)));
+
+    auto oneMinusZtUtName = nameBase + "1.-zu." + std::to_string(t);
+    auto *OneMinusZtUt = createMul(oneMinusZtUtName, OneMinusZt, Ut);
+
+    auto htName = nameBase + ".H." + std::to_string(t);
+    Ht = createAdd(htName, ZHt, OneMinusZtUt);
+
+    auto outName = nameBase + ".out." + std::to_string(t);
+    auto *O = createFullyConnected(outName, Ht, Why, By);
+    outputs.push_back(O);
+  }
+}
+
+void Function::createSimpleRNN(Context &ctx, llvm::StringRef namePrefix,
+                               llvm::ArrayRef<Node *> inputs,
+                               unsigned batchSize, unsigned hiddenSize,
+                               unsigned outputSize,
+                               std::vector<NodeValue> &outputs) {
+  std::string nameBase = namePrefix;
+  const unsigned timeSteps = inputs.size();
+  assert(timeSteps > 0 && "empty input");
+  const unsigned inputSize = inputs.front()->dims(0).back();
+  assert(inputSize > 0 && "input dimensionality is zero");
+
+  // Initialize the state to zero.
+  Placeholder *HInit =
+      getParent()->createPlaceholder(ElemKind::FloatTy, {batchSize, hiddenSize},
+                                     nameBase + ".initial_state", false);
+  ctx.allocate(HInit)->zero();
+  Node *Ht = HInit;
+
+  float b = 0.1;
+  Placeholder *Whh = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whh", true);
+  Placeholder *Bhh = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize}, nameBase + ".Bhh", true);
+  Placeholder *Wxh = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxh", true);
+
+  Placeholder *Bxh = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize}, nameBase + ".Bxh", true);
+  Placeholder *Why = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize, outputSize}, nameBase + ".Why", true);
+  Placeholder *Bhy = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {outputSize}, nameBase + ".Bhy", true);
+
+  ctx.allocate(Whh)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
+                          getPRNG());
+  ctx.allocate(Bhh)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
+  ctx.allocate(Wxh)->init(glow::Tensor::InitKind::Xavier, inputSize, getPRNG());
+  ctx.allocate(Bxh)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
+  ctx.allocate(Why)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
+                          getPRNG());
+  ctx.allocate(Bhy)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
+
+  // Un-roll backpropagation through time as a loop with the shared parameters.
+  for (unsigned t = 0; t < timeSteps; t++) {
+    auto fc1Name = nameBase + ".fc1." + std::to_string(t);
+    auto *FC1 = createFullyConnected(fc1Name, Ht, Whh, Bhh);
+    auto fc2Name = nameBase + ".fc2." + std::to_string(t);
+    auto *FC2 = createFullyConnected(fc2Name, inputs[t], Wxh, Bxh);
+    auto aName = nameBase + ".add." + std::to_string(t);
+    auto *A = createAdd(aName, FC1, FC2);
+    auto tanhName = nameBase + ".tanh." + std::to_string(t);
+    auto *H = createTanh(tanhName, A);
+    auto outName = nameBase + ".out." + std::to_string(t);
+    auto *O = createFullyConnected(outName, H, Why, Bhy);
+    outputs.push_back(O);
+
+    Ht = H;
+  };
+}
+
+void Function::createLSTM(Context &ctx, llvm::StringRef namePrefix,
+                          llvm::ArrayRef<Node *> inputs, unsigned batchSize,
+                          unsigned hiddenSize, unsigned outputSize,
+                          std::vector<NodeValue> &outputs) {
+  std::string nameBase = namePrefix;
+  const unsigned timeSteps = inputs.size();
+  assert(timeSteps > 0 && "empty input");
+  const unsigned inputSize = inputs.front()->dims(0).back();
+  assert(inputSize > 0 && "input dimensionality is zero");
+
+  // Initialize the hidden and cell states to zero.
+  Placeholder *HInit =
+      getParent()->createPlaceholder(ElemKind::FloatTy, {batchSize, hiddenSize},
+                                     "initial_hidden_state", false);
+  ctx.allocate(HInit)->zero();
+  Node *Ht = HInit;
+
+  Placeholder *CInit = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {batchSize, hiddenSize}, "initial_cell_state", false);
+  ctx.allocate(CInit)->zero();
+  Node *Ct = CInit;
+
+  // Forget gate:
+  // F <- sigmoid(Wxf * x + Whf * h + bf)
+  // Input gate:
+  // I <- sigmoid(Wxi * x + Whi * h + bi)
+  // Output gate:
+  // O <- sigmoid(Wxo * x + Who * h + bo)
+  // Cell state:
+  // C <- F . C + I . tanh(Wxc * x + Whc * h + bc)
+  // Hidden state:
+  // h <- O . tanh(C)
+
+  // forget gate
+  float bForget = 1.0;
+  Placeholder *Wxf = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxf", true);
+  Placeholder *Whf = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whf", true);
+  Placeholder *Bf1 = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bf1", true);
+  Placeholder *Bf2 = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bf2", true);
+  ctx.allocate(Wxf)->init(glow::Tensor::InitKind::Xavier, inputSize, getPRNG());
+  ctx.allocate(Whf)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
+                          getPRNG());
+  ctx.allocate(Bf1)->init(glow::Tensor::InitKind::Broadcast, bForget,
+                          getPRNG());
+  ctx.allocate(Bf2)->init(glow::Tensor::InitKind::Broadcast, bForget,
+                          getPRNG());
+
+  // input gate
+  float bInput = 0.1;
+  Placeholder *Wxi = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxi", true);
+  Placeholder *Whi = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whi", true);
+  Placeholder *Bi1 = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bi1", true);
+  Placeholder *Bi2 = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bi2", true);
+
+  ctx.allocate(Wxi)->init(glow::Tensor::InitKind::Xavier, inputSize, getPRNG());
+  ctx.allocate(Whi)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
+                          getPRNG());
+  ctx.allocate(Bi1)->init(glow::Tensor::InitKind::Broadcast, bInput, getPRNG());
+  ctx.allocate(Bi2)->init(glow::Tensor::InitKind::Broadcast, bInput, getPRNG());
+
+  // output gate
+  float bOutput = 0.1;
+  Placeholder *Wxo = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxo", true);
+  Placeholder *Who = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Who", true);
+  Placeholder *Bo1 = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bo1", true);
+  Placeholder *Bo2 = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bo2", true);
+
+  ctx.allocate(Wxo)->init(glow::Tensor::InitKind::Xavier, inputSize, getPRNG());
+  ctx.allocate(Who)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
+                          getPRNG());
+  ctx.allocate(Bo1)->init(glow::Tensor::InitKind::Broadcast, bOutput,
+                          getPRNG());
+  ctx.allocate(Bo2)->init(glow::Tensor::InitKind::Broadcast, bOutput,
+                          getPRNG());
+
+  // cell state
+  float bCell = 0.1;
+  Placeholder *Wxc = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxc", true);
+  Placeholder *Whc = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whc", true);
+  Placeholder *Bc1 = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bc1", true);
+  Placeholder *Bc2 = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bc2", true);
+
+  ctx.allocate(Wxc)->init(glow::Tensor::InitKind::Xavier, inputSize, getPRNG());
+  ctx.allocate(Whc)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
+                          getPRNG());
+  ctx.allocate(Bc1)->init(glow::Tensor::InitKind::Broadcast, bCell, getPRNG());
+  ctx.allocate(Bc2)->init(glow::Tensor::InitKind::Broadcast, bCell, getPRNG());
+
+  // output layer
+  float b = 0.1;
+  Placeholder *Why = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {hiddenSize, outputSize}, nameBase + ".Why", true);
+  Placeholder *By = getParent()->createPlaceholder(
+      ElemKind::FloatTy, {outputSize}, nameBase + ".by", true);
+
+  ctx.allocate(Why)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
+                          getPRNG());
+  ctx.allocate(By)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
+
+  std::vector<Node *> outputNodes;
+  for (unsigned t = 0; t < timeSteps; t++) {
+    auto fc1Name = nameBase + ".fc1." + std::to_string(t);
+    auto fc2Name = nameBase + ".fc2." + std::to_string(t);
+    auto add1Name = nameBase + ".add1." + std::to_string(t);
+    auto sigmoid1Name = nameBase + ".sigmoid1." + std::to_string(t);
+
+    auto *Ft = createSigmoid(
+        sigmoid1Name,
+        createAdd(add1Name, createFullyConnected(fc1Name, Ht, Whf, Bf1),
+                  createFullyConnected(fc2Name, inputs[t], Wxf, Bf2)));
+
+    auto fc3Name = nameBase + ".fc3." + std::to_string(t);
+    auto fc4Name = nameBase + ".fc4." + std::to_string(t);
+    auto add2Name = nameBase + ".add2." + std::to_string(t);
+    auto sigmoid2Name = nameBase + ".sigmoid2." + std::to_string(t);
+
+    auto *It = createSigmoid(
+        sigmoid2Name,
+        createAdd(add2Name, createFullyConnected(fc3Name, Ht, Whi, Bi1),
+                  createFullyConnected(fc4Name, inputs[t], Wxi, Bi2)));
+
+    auto fc5Name = nameBase + ".fc5." + std::to_string(t);
+    auto fc6Name = nameBase + ".fc6." + std::to_string(t);
+    auto add3Name = nameBase + ".add3." + std::to_string(t);
+    auto sigmoid3Name = nameBase + ".sigmoid3." + std::to_string(t);
+
+    auto *Ot = createSigmoid(
+        sigmoid3Name,
+        createAdd(add3Name, createFullyConnected(fc5Name, Ht, Who, Bo1),
+                  createFullyConnected(fc6Name, inputs[t], Wxo, Bo2)));
+
+    auto fc7Name = nameBase + ".fc7." + std::to_string(t);
+    auto fc8Name = nameBase + ".fc8." + std::to_string(t);
+    auto add4Name = nameBase + ".add4." + std::to_string(t);
+    auto tanh1Name = nameBase + ".tanh1." + std::to_string(t);
+
+    auto *CRt = createTanh(
+        tanh1Name,
+        createAdd(add4Name, createFullyConnected(fc7Name, Ht, Whc, Bc1),
+                  createFullyConnected(fc8Name, inputs[t], Wxc, Bc2)));
+
+    auto mul1Name = nameBase + ".mul1." + std::to_string(t);
+    auto mul2Name = nameBase + ".mul2." + std::to_string(t);
+    Ct = createAdd(nameBase + ".C." + std::to_string(t),
+                   createMul(mul1Name, Ft, Ct), createMul(mul2Name, It, CRt));
+
+    auto htName = nameBase + ".H." + std::to_string(t);
+    auto tanh2Name = nameBase + ".tanh2." + std::to_string(t);
+    Ht = createMul(htName, Ot, createTanh(tanh2Name, Ct));
+
+    auto outName = nameBase + ".out." + std::to_string(t);
+    auto *O = createFullyConnected(outName, Ht, Why, By);
+    outputs.push_back(O);
+  }
+};
 
 //===----------------------------------------------------------------------===//
 //                       Graph dumping and printing
 //===----------------------------------------------------------------------===//
diff --git a/tests/unittests/MLTest.cpp b/tests/unittests/MLTest.cpp
index 1a1921c145..7d1eb4755f 100644
--- a/tests/unittests/MLTest.cpp
+++ b/tests/unittests/MLTest.cpp
@@ -547,26 +547,26 @@ TEST_P(MLTest, learnSingleValueConcat) {
   EXPECT_NEAR(RNWH.at({0, 0}), 0.9, 0.1);
 }
 
-void buildGRU(Function *F, const std::vector<Node *> &slicesX,
+void buildGRU(Context &ctx, Function *F, const std::vector<Node *> &slicesX,
               unsigned hiddenSize, unsigned outputSize,
              std::vector<NodeValue> &outputs) {
-  return F->createGRU("GRU", slicesX, 1, hiddenSize, outputSize, outputs);
+  return F->createGRU(ctx, "GRU", slicesX, 1, hiddenSize, outputSize, outputs);
 };
 
-void buildRNN(Function *F, const std::vector<Node *> &slicesX,
+void buildRNN(Context &ctx, Function *F, const std::vector<Node *> &slicesX,
              unsigned hiddenSize, unsigned outputSize,
              std::vector<NodeValue> &outputs) {
-  return F->createSimpleRNN("SimpleRNN", slicesX, 1, hiddenSize, outputSize,
+  return F->createSimpleRNN(ctx, "SimpleRNN", slicesX, 1, hiddenSize, outputSize,
                             outputs);
 };
 
-void buildLSTM(Function *F, const std::vector<Node *> &slicesX,
+void buildLSTM(Context &ctx, Function *F, const std::vector<Node *> &slicesX,
               unsigned hiddenSize, unsigned outputSize,
              std::vector<NodeValue> &outputs) {
-  return F->createLSTM("LSTM", slicesX, 1, hiddenSize, outputSize, outputs);
+  return F->createLSTM(ctx, "LSTM", slicesX, 1, hiddenSize, outputSize, outputs);
 };
 
-using TCellGenerator = void (*)(Function *, const std::vector<Node *> &,
+using TCellGenerator = void (*)(Context &, Function *, const std::vector<Node *> &,
                                 unsigned, unsigned, std::vector<NodeValue> &);
 
 void testRNNCell(TCellGenerator cell) {
@@ -588,10 +588,12 @@ void testRNNCell(TCellGenerator cell) {
   const unsigned NumElements = 4;
   // Create a variable with 1 input, which is 3 consecutive vectors
   // of 4 elements each.
-  auto *X = mod.createVariable(ElemKind::FloatTy, {1, NumVectors, NumElements},
-                               "X", VisibilityKind::Public, false);
-  auto *Y = mod.createVariable(ElemKind::FloatTy, {1, NumVectors}, "Y",
-                               VisibilityKind::Public, false);
+  Placeholder *X = mod.createPlaceholder(ElemKind::FloatTy, {1, NumVectors, NumElements},
+                                         "X", false);
+  Placeholder *Y = mod.createPlaceholder(ElemKind::FloatTy, {1, NumVectors}, "Y",
+                                         false);
+  ctx.allocate(X);
+  ctx.allocate(Y);
 
   // Extract a slice for each input.
   std::vector<Node *> XSliced;
@@ -615,7 +617,7 @@ void testRNNCell(TCellGenerator cell) {
   const unsigned outputSize = 1;
 
   std::vector<NodeValue> outputNodes;
-  cell(F, XSliced, hiddenSize, outputSize, outputNodes);
+  cell(ctx, F, XSliced, hiddenSize, outputSize, outputNodes);
 
   std::vector<NodeValue> regressionNodes;
   for (unsigned t = 0; t < NumVectors; t++) {
@@ -624,7 +626,9 @@ void testRNNCell(TCellGenerator cell) {
   };
 
   auto *R = F->createConcat("O", regressionNodes, 1);
-  auto *result = F->createSave("result", R);
+  SaveNode *result = F->createSave(ctx, "result", R);
+
+  Tensor *res = ctx.allocate(result->getPlaceholder());
 
   Function *TF = glow::differentiate(F, TC);
   EE.compile(CompilationMode::Train, TF, ctx);
@@ -640,15 +644,15 @@ void testRNNCell(TCellGenerator cell) {
   }
 
   // Train the network. Learn 1000 batches.
-  runBatch(EE, 1000, sampleCounter, {X, Y}, {&inputs, &expected});
+  runBatch(EE, ctx, 1000, sampleCounter, {X, Y}, {&inputs, &expected});
 
   // Testing the output vector.
   EE.compile(CompilationMode::Infer, F, ctx);
 
-  updateVariables({X}, {&inputs});
+  updateVariables(ctx, {X}, {&inputs});
   EE.run();
 
-  auto RNWH = result->getVariable()->getPayload().getHandle<>();
+  auto RNWH = res->getHandle<>();
   (void)RNWH;
 
   // Test the output:
@@ -663,41 +667,6 @@ TEST_P(MLTest, trainGRU) { testRNNCell(buildGRU); };
 
 TEST_P(MLTest, trainLSTM) { testRNNCell(buildLSTM); };
 
-/// Learn the square root of two.
-TEST_P(MLTest, learnSqrt2) {
-  TrainingConfig TC;
-  Context ctx;
-
-  TC.learningRate = 0.03;
-
-  auto &mod = EE_.getModule();
-  Function *F = mod.createFunction("Square root of 2");
-
-  auto *A = mod.createVariable(ElemKind::FloatTy, {1}, "A",
-                               VisibilityKind::Public, true);
-  A->getPayload().init(Tensor::InitKind::Broadcast, 1, mod.getPRNG());
-
-  auto *Ex = mod.createVariable(ElemKind::FloatTy, {1}, "Ex",
-                                VisibilityKind::Public, false);
-  Ex->getPayload().getHandle() = {2};
-
-  Node *O = F->createMul("Mult", A, A);
-  O = F->createRegression("reg", O, Ex);
-  F->createSave("ret", O);
-
-  Function *TF = glow::differentiate(F, TC);
-  EE_.compile(CompilationMode::Train, TF, ctx);
-
-  // Train the network:
-  for (int i = 0; i < 50; i++) {
-    EE_.run();
-    EE_.run();
-  }
-
-  float res = A->getPayload().getHandle().at({0});
-  EXPECT_NEAR(res, 1.4142, 0.01);
-}
-
 TEST_P(MLTest, trainSimpleLinearRegression) {
   TrainingConfig TC;
   Context ctx;
@@ -731,28 +700,32 @@ TEST_P(MLTest, trainSimpleLinearRegression) {
   }
 
   // Create a variable with 1 input, which is a real number.
-  auto *inputX = mod.createVariable(ElemKind::FloatTy, {numSamples, 1}, "input",
-                                    VisibilityKind::Public, false);
-  auto *expectedY =
-      mod.createVariable(ElemKind::FloatTy, {numSamples, 1}, "expected",
-                         VisibilityKind::Public, false);
+  Placeholder *inputX =
+      mod.createPlaceholder(ElemKind::FloatTy, {numSamples, 1}, "input", false);
+  Placeholder *expectedY = mod.createPlaceholder(
+      ElemKind::FloatTy, {numSamples, 1}, "expected", false);
 
-  FullyConnectedNode *FC = F->createFullyConnected("fc", inputX, 1);
+  FullyConnectedNode *FC = F->createFullyConnected(ctx, "fc", inputX, 1);
   Node *R = F->createRegression("reg", FC, expectedY);
-  F->createSave("return", R);
+  SaveNode *SN = F->createSave(ctx, "return", R);
 
-  Variable *M = llvm::cast<Variable>(FC->getWeights());
-  Variable *B = llvm::cast<Variable>(FC->getBias());
+  ctx.allocate(inputX);
+  ctx.allocate(expectedY);
+  ctx.allocate(SN->getPlaceholder());
+
+  Placeholder *M = llvm::cast<Placeholder>(FC->getWeights());
+  Placeholder *B = llvm::cast<Placeholder>(FC->getBias());
 
   Function *TF = glow::differentiate(F, TC);
   EE_.compile(CompilationMode::Train, TF, ctx);
 
   // Train the network doing 100 steps. Learn on 500 samples.
-  runBatch(EE_, 100, sampleCounter, {inputX, expectedY}, {&tensorX, &tensorY});
+  runBatch(EE_, ctx, 100, sampleCounter, {inputX, expectedY},
+           {&tensorX, &tensorY});
 
   // Testing trained m and b:
-  EXPECT_NEAR(M->getPayload().getHandle<>().at({0, 0}), referenceM, 0.01);
-  EXPECT_NEAR(B->getPayload().getHandle<>().at({0}), referenceB, 0.01);
+  EXPECT_NEAR(ctx.get(M)->getHandle<>().at({0, 0}), referenceM, 0.01);
+  EXPECT_NEAR(ctx.get(B)->getHandle<>().at({0}), referenceB, 0.01);
 }
 
 enum class Sport : size_t { BASKETBALL = 0, SOCCER = 1 };
@@ -803,15 +776,18 @@ TEST_P(MLTest, classifyPlayerSport) {
   auto &mod = EE_.getModule();
   Function *F = mod.createFunction("classifyPlayers");
 
-  auto *A =
-      mod.createVariable(ElemKind::FloatTy, {numTrainPlayers, numFeatures}, "A",
-                         VisibilityKind::Public, false);
-  auto *S = mod.createVariable(ElemKind::Int64ITy, {numTrainPlayers, 1}, "S",
-                               VisibilityKind::Public, false);
+  Placeholder *A = mod.createPlaceholder(
+      ElemKind::FloatTy, {numTrainPlayers, numFeatures}, "A", false);
+  Placeholder *S = mod.createPlaceholder(ElemKind::Int64ITy,
+                                         {numTrainPlayers, 1}, "S", false);
 
-  auto *FC = F->createFullyConnected("fc", A, numClasses);
+  auto *FC = F->createFullyConnected(ctx, "fc", A, numClasses);
   auto *SM = F->createSoftMax("softmax", FC, S);
-  auto *result = F->createSave("result", SM);
+  SaveNode *result = F->createSave(ctx, "result", SM);
+
+  ctx.allocate(A);
+  ctx.allocate(S);
+  ctx.allocate(result->getPlaceholder());
 
   Function *TF = glow::differentiate(F, TC);
   EE_.compile(CompilationMode::Train, TF, ctx);
@@ -821,7 +797,7 @@ TEST_P(MLTest, classifyPlayerSport) {
   generatePlayerData(players, labels, numTrainPlayers, mod.getPRNG());
 
   // Training:
-  runBatch(EE_, 2000, sampleCounter, {A, S}, {&players, &labels});
+  runBatch(EE_, ctx, 2000, sampleCounter, {A, S}, {&players, &labels});
 
   EE_.compile(CompilationMode::Infer, F, ctx);
 
@@ -839,11 +815,11 @@ TEST_P(MLTest, classifyPlayerSport) {
     testPlayersTensor.getHandle<>().at({i, 1}) = std::get<1>(testPlayers[i]);
   }
 
-  updateVariables({A}, {&testPlayersTensor});
+  updateVariables(ctx, {A}, {&testPlayersTensor});
   EE_.run();
 
+  auto SMH = ctx.get(result->getPlaceholder())->getHandle<>();
   for (size_t i = 0; i < testPlayers.size(); i++) {
-    auto SMH = result->getVariable()->getPayload().getHandle<>();
    const size_t sport = static_cast<size_t>(std::get<2>(testPlayers[i]));
     EXPECT_NEAR(SMH.at({i, sport}), 1.0, 0.1);
   }
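
A minimal usage sketch of the new Context/Placeholder-based API, mirroring the updated testRNNCell above. It assumes a Function *F, an allocated Context, and per-timestep input nodes already built (names and sizes are illustrative only, not part of the patch):

  Context ctx;
  std::vector<Node *> slices;    // one {1, inputSize} input node per timestep,
                                 // e.g. reshaped slices of an allocated Placeholder
  std::vector<NodeValue> outputs;
  F->createLSTM(ctx, "lstm", slices, /*batchSize=*/1, /*hiddenSize=*/5,
                /*outputSize=*/1, outputs);
  // `outputs` holds one NodeValue per timestep; concatenate them and attach a
  // SaveNode created with the same Context to read the results after running.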