@@ -106,6 +106,63 @@ TEST(Interpreter, profileQuantizationForANetwork) {
106
106
EXPECT_NEAR (1.6 , max, 0.00001 );
107
107
}
108
108
109
+ // / Test that the symbol category for a symbol is properly set.
110
+ TEST (RuntimeBundle, BundleSymbolInfo) {
111
+ Module mod;
112
+ ExecutionEngine EE;
113
+ PlaceholderBindings bindings;
114
+
115
+ Tensor inputs (ElemKind::FloatTy, {1 , 10 , 10 , 3 });
116
+ inputs.getHandle ().randomize (-2 , 2 , mod.getPRNG ());
117
+
118
+ // Create a simple graph that has placeholders, constants, activations, and a
119
+ // tensor_view.
120
+ Function *F = mod.createFunction (" main" );
121
+ auto *input =
122
+ mod.createPlaceholder (ElemKind::FloatTy, {1 , 10 , 10 , 3 }, " in" , false );
123
+
124
+ auto *ex = mod.createConstant (ElemKind::Int64ITy, {1 , 1 }, " exp" );
125
+
126
+ auto *FC = F->createFullyConnected (bindings, " FC" , input, 30 );
127
+ auto *RL = F->createRELU (" RL2" , FC);
128
+ auto *SM = F->createSoftMax (" sm" , RL, ex);
129
+ auto *S = F->createSave (" ret" , SM);
130
+ auto *qp = F->createQuantizationProfile (bindings, " qp" , input);
131
+
132
+ EE.compile (CompilationMode::Infer, F);
133
+ auto table = EE.getCompiledFunction ().getRuntimeBundle ().getSymbolTable ();
134
+ // Check that placeholders and constants are correctly labelled.
135
+ EXPECT_EQ (table.find (S->getName ())->second .symbolCategory ,
136
+ glow::runtime::SymbolCategory::Placeholder);
137
+ EXPECT_EQ (table.find (ex->getName ())->second .symbolCategory ,
138
+ glow::runtime::SymbolCategory::Constant);
139
+ // Check that activations are labelled correctly.
140
+ EXPECT_EQ (table.find (" fc_add_bias_res" )->second .symbolCategory ,
141
+ glow::runtime::SymbolCategory::Activation);
142
+ // Check that tensor views have the same label as their parent symbol. In this
143
+ // case same as "input".
144
+ EXPECT_EQ (table.find (" tensorview_reshape" )->second .symbolCategory ,
145
+ glow::runtime::SymbolCategory::PlaceholderTensorView);
146
+
147
+ // Check that placeholders and constants input/output flags are correctly set.
148
+ EXPECT_EQ (table.find (S->getName ())->second .input , false );
149
+ EXPECT_EQ (table.find (S->getName ())->second .output , true );
150
+ EXPECT_EQ (table.find (ex->getName ())->second .input , false );
151
+ EXPECT_EQ (table.find (ex->getName ())->second .output , false );
152
+ EXPECT_EQ (table.find (input->getName ())->second .input , true );
153
+ EXPECT_EQ (table.find (input->getName ())->second .output , false );
154
+ EXPECT_EQ (table.find (qp->getHistogramPlaceholder ()->getName ())->second .input ,
155
+ true );
156
+ EXPECT_EQ (table.find (qp->getHistogramPlaceholder ()->getName ())->second .output ,
157
+ true );
158
+ // Check that activations are labelled correctly.
159
+ EXPECT_EQ (table.find (" fc_add_bias_res" )->second .input , false );
160
+ EXPECT_EQ (table.find (" fc_add_bias_res" )->second .output , false );
161
+ // Check that tensor views are labelled correctly.
162
+ EXPECT_EQ (table.find (" tensorview_reshape" )->second .input , false );
163
+ EXPECT_EQ (table.find (" tensorview_reshape" )->second .output , false );
164
+ }
165
+
109
166
TEST_P (BackendTest, simpleInference) {
110
167
Tensor inputs (ElemKind::FloatTy, {1 , 32 , 32 , 3 });
111
168
PlaceholderBindings ctx;
@@ -247,43 +304,6 @@ TEST_P(BackendTest, BundleSharedConstant) {
247
304
EXPECT_TRUE (it2 != table2.end ());
248
305
}
249
306
250
- // / Test that the symbol category for a symbol is properly set.
251
- TEST_P (BackendTest, BundleSymbolCategory) {
252
- Module mod;
253
- PlaceholderBindings bindings;
254
-
255
- Tensor inputs (ElemKind::FloatTy, {1 , 10 , 10 , 3 });
256
- inputs.getHandle ().randomize (-2 , 2 , mod.getPRNG ());
257
-
258
- // Create a simple graph that has placeholders, constants, activations, and a
259
- // tensor_view.
260
- Function *F = mod.createFunction (" main" );
261
- auto *input =
262
- mod.createPlaceholder (ElemKind::FloatTy, {1 , 10 , 10 , 3 }, " in" , false );
263
-
264
- auto *ex = mod.createConstant (ElemKind::Int64ITy, {1 , 1 }, " exp" );
265
-
266
- auto *FC = F->createFullyConnected (bindings, " FC" , input, 30 );
267
- auto *RL = F->createRELU (" RL2" , FC);
268
- auto *SM = F->createSoftMax (" sm" , RL, ex);
269
- auto *S = F->createSave (" ret" , SM);
270
-
271
- EE_.compile (CompilationMode::Infer, F);
272
- auto table = EE_.getCompiledFunction ().getRuntimeBundle ().getSymbolTable ();
273
- // Check that placeholders and constants are correctly labelled.
274
- EXPECT_EQ (table.find (S->getName ())->second .symbolCategory ,
275
- glow::runtime::SymbolCategory::Placeholder);
276
- EXPECT_EQ (table.find (ex->getName ())->second .symbolCategory ,
277
- glow::runtime::SymbolCategory::Constant);
278
- // Check that activations are labelled correctly.
279
- EXPECT_EQ (table.find (" fc_add_bias_res" )->second .symbolCategory ,
280
- glow::runtime::SymbolCategory::Activation);
281
- // Check that tensor views have the same label as their parent symbol. In this
282
- // case same as "input".
283
- EXPECT_EQ (table.find (" tensorview_reshape" )->second .symbolCategory ,
284
- glow::runtime::SymbolCategory::PlaceholderTensorView);
285
- }
286
-
287
307
// / Test compiling a vector of functions completes without error.
288
308
TEST_P (BackendTest, compileVectorOfFunctions) {
289
309
Module mod;
0 commit comments