@@ -36,6 +36,42 @@ class TestRunnerBase : public ::testing::TestWithParam<BackendKind> {
class InterpreterAndCPU : public TestRunnerBase {};
class MLTest : public TestRunnerBase {};

+/// Use placeholders (and not variables) to learn the square root of two.
+TEST_P(MLTest, learnSqrt2Placeholder) {
+  TrainingConfig TC;
+  Context ctx;
+
+  TC.learningRate = 0.03;
+
+  auto &mod = EE_.getModule();
+  Function *F = mod.createFunction("Square root of 2");
+
+  auto *A = mod.createPlaceholder(ElemKind::FloatTy, {1}, "A", true);
+  auto *inputTensor = ctx.allocate(A);
+  inputTensor->init(Tensor::InitKind::Broadcast, 1, mod.getPRNG());
+
+  auto *E = mod.createPlaceholder(ElemKind::FloatTy, {1}, "Ex", false);
+  ctx.allocate(E)->getHandle() = {2};
+
+  auto *O = mod.createPlaceholder(ElemKind::FloatTy, {1}, "output", false);
+  ctx.allocate(O);
+
+  Node *M = F->createMul("Mult", A, A);
+  M = F->createRegression("reg", M, E);
+  F->createSave("ret", M);
+
+  Function *TF = glow::differentiate(F, TC);
+  EE_.compile(CompilationMode::Train, TF, ctx);
+
+  // Train the network:
+  for (int i = 0; i < 100; i++) {
+    EE_.run();
+  }
+
+  float res = inputTensor->getHandle().at({0});
+  EXPECT_NEAR(res, 1.4142, 0.01);
+}
+
TEST_P(MLTest, trainASimpleNetwork) {
  TrainingConfig TC;
  Context ctx;
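For context, the new test trains a single scalar placeholder A so that A * A regresses toward 2; after training, A should approach sqrt(2) ≈ 1.4142, which is what the final EXPECT_NEAR asserts. Below is a minimal standalone sketch of the gradient-descent update this relies on, in plain C++ with no Glow dependency; it assumes a simple squared-error loss L(a) = (a² − 2)², and the exact loss scaling used by Glow's RegressionNode may differ.

```cpp
#include <cmath>
#include <cstdio>

int main() {
  // Learn sqrt(2) by gradient descent on L(a) = (a*a - 2)^2, mirroring the
  // test's setup: initial value 1 (Broadcast init), learning rate 0.03,
  // 100 training steps. This is only an illustration, not Glow's trainer.
  double a = 1.0;         // corresponds to the "A" placeholder
  const double lr = 0.03; // corresponds to TC.learningRate
  for (int i = 0; i < 100; i++) {
    double grad = 4.0 * a * (a * a - 2.0); // dL/da
    a -= lr * grad;
  }
  std::printf("a = %f (expected ~%f)\n", a, std::sqrt(2.0));
  return 0;
}
```

Run on its own, this converges to roughly 1.41421, consistent with the 0.01 tolerance the test allows.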