Skip to content
This repository was archived by the owner on Jul 1, 2025. It is now read-only.

Commit 48adbe7

Browse files
committed
[Placeholder] Add a test that uses placeholders instead of variables.
1 parent 070f1e2 commit 48adbe7

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

tests/unittests/MLTest.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,42 @@ class TestRunnerBase : public ::testing::TestWithParam<BackendKind> {
3636
class InterpreterAndCPU : public TestRunnerBase {};
3737
class MLTest : public TestRunnerBase {};
3838

39+
/// Use placeholders (and not variables) to learn the square root of two.
40+
TEST_P(MLTest, learnSqrt2Placeholder) {
41+
TrainingConfig TC;
42+
Context ctx;
43+
44+
TC.learningRate = 0.03;
45+
46+
auto &mod = EE_.getModule();
47+
Function *F = mod.createFunction("Square root of 2");
48+
49+
auto *A = mod.createPlaceholder(ElemKind::FloatTy, {1}, "A", true);
50+
auto *inputTensor = ctx.allocate(A);
51+
inputTensor->init(Tensor::InitKind::Broadcast, 1, mod.getPRNG());
52+
53+
auto *E = mod.createPlaceholder(ElemKind::FloatTy, {1}, "Ex", false);
54+
ctx.allocate(E)->getHandle() = {2};
55+
56+
auto *O = mod.createPlaceholder(ElemKind::FloatTy, {1}, "output", false);
57+
ctx.allocate(O);
58+
59+
Node *M = F->createMul("Mult", A, A);
60+
M = F->createRegression("reg", M, E);
61+
F->createSave("ret", M);
62+
63+
Function *TF = glow::differentiate(F, TC);
64+
EE_.compile(CompilationMode::Train, TF, ctx);
65+
66+
// Train the network:
67+
for (int i = 0; i < 100; i++) {
68+
EE_.run();
69+
}
70+
71+
float res = inputTensor->getHandle().at({0});
72+
EXPECT_NEAR(res, 1.4142, 0.01);
73+
}
74+
3975
TEST_P(MLTest, trainASimpleNetwork) {
4076
TrainingConfig TC;
4177
Context ctx;

0 commit comments

Comments
 (0)