1
+ #include < c10/util/Exception.h>
1
2
#include < torch/csrc/jit/frontend/ir_emitter.h>
3
+ #include < torch/csrc/jit/ir/ir_views.h>
2
4
#include < torch/csrc/jit/jit_log.h>
3
5
#include < torch/csrc/jit/passes/inliner.h>
6
+ #include < torch/csrc/jit/runtime/graph_iterator.h>
4
7
#include < torch/csrc/jit/runtime/operator.h>
5
8
#include < torch/csrc/jit/runtime/symbolic_shape_registry.h>
6
9
#include < torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
@@ -160,26 +163,121 @@ const at::optional<const FunctionSchema*> getInplaceVariant(
160
163
return at::nullopt;
161
164
}
162
165
163
- void registerSchema (
164
- const FunctionSchema* schema_string,
165
- const std::string& shape_compute_function_name,
166
- std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
167
- const CompilationUnit& module) {
168
- if (reused_functions.count (shape_compute_function_name)) {
169
- auto graph = reused_functions[shape_compute_function_name];
166
+ TypePtr mapTensorToListOfInts (TypePtr type) {
167
+ if (type->cast <TensorType>()) {
168
+ return ListType::ofInts ();
169
+ }
170
+ at::ArrayRef<TypePtr> contained = type->containedTypes ();
171
+ if (contained.empty ()) {
172
+ return type;
173
+ }
174
+ return type->withContained (
175
+ fmap (type->containedTypes (), mapTensorToListOfInts));
176
+ }
170
177
171
- // allow extra unused arguments to map multiple functions to e.g. unary
178
+ void checkForWhileLoop (
179
+ const FunctionSchema* schema,
180
+ std::shared_ptr<Graph> graph) {
181
+ DepthFirstGraphNodeIterator graph_it (graph);
182
+ for (auto * node = graph_it.next (); node != nullptr ; node = graph_it.next ()) {
183
+ if (node->kind () != prim::Loop) {
184
+ continue ;
185
+ }
186
+ LoopView loop (node);
187
+ if (loop.loopType () != LoopView::For) {
188
+ TORCH_WARN (
189
+ " While loops are not yet implemented in unrolling which may make this shape function difficult to partially evaluate: " ,
190
+ *node,
191
+ " for schema " ,
192
+ *schema);
193
+ }
194
+ }
195
+ }
196
+
197
+ void checkInputReturnedAsOutput (
198
+ const FunctionSchema* schema,
199
+ const std::shared_ptr<Graph>& graph) {
200
+ // Could use alias db here as well but would have to warn because it's
201
+ // imprecise
202
+ for (size_t i : c10::irange (graph->inputs ().size ())) {
203
+ Value* input = graph->inputs ().at (i);
204
+ for (size_t j : c10::irange (graph->outputs ().size ())) {
205
+ Value* output = graph->outputs ().at (j);
206
+ TORCH_CHECK (
207
+ input != output,
208
+ " For schema: " ,
209
+ *schema,
210
+ " input index " ,
211
+ i,
212
+ " is returned as output index " ,
213
+ j,
214
+ " . Shape functions must return new unaliased lists" );
215
+ }
216
+ }
217
+ }
218
+
219
+ void checkInputAndOutputTypes (
220
+ const FunctionSchema* schema,
221
+ const std::shared_ptr<Graph>& graph) {
222
+ // allow extra unused arguments to map multiple functions to e.g. unary
223
+ TORCH_CHECK (
224
+ graph->inputs ().size () <= schema->arguments ().size (),
225
+ " Shape function must have fewer arguments than schema. Got " ,
226
+ graph->inputs ().size (),
227
+ " graph arguments and " ,
228
+ schema->arguments ().size (),
229
+ " schema arguments of schema: " ,
230
+ *schema);
231
+
232
+ for (auto i : c10::irange (graph->inputs ().size ())) {
233
+ auto inp_type = schema->arguments ().at (i).type ();
234
+ auto mapped_type = mapTensorToListOfInts (inp_type);
235
+ auto graph_type = graph->inputs ().at (i)->type ();
172
236
TORCH_INTERNAL_ASSERT (
173
- graph->inputs ().size () <= schema_string->arguments ().size ());
237
+ mapped_type->isSubtypeOf (graph->inputs ().at (i)->type ()),
238
+ " For schema type: " ,
239
+ inp_type->str (),
240
+ " Expected supertype of " ,
241
+ mapped_type->str (),
242
+ " but got graph_type " ,
243
+ graph_type->str (),
244
+ " at index " ,
245
+ i,
246
+ " of schema: " ,
247
+ *schema);
248
+ }
174
249
175
- cached_schema_to_graph[schema_string] = graph;
176
- return ;
250
+ TORCH_CHECK (
251
+ graph->outputs ().size () == schema->returns ().size (),
252
+ " Shape function equal number of outputs as schema. Got " ,
253
+ graph->outputs ().size (),
254
+ " graph outputs and " ,
255
+ schema->returns ().size (),
256
+ " schema returns of schema: " ,
257
+ *schema);
258
+
259
+ for (auto i : c10::irange (schema->returns ().size ())) {
260
+ auto out_type = schema->returns ().at (i).type ();
261
+ auto mapped_type = mapTensorToListOfInts (out_type);
262
+ auto graph_type = graph->outputs ().at (i)->type ();
263
+ TORCH_INTERNAL_ASSERT (
264
+ mapped_type->isSubtypeOf (graph->outputs ().at (i)->type ()),
265
+ " For schema type: " ,
266
+ out_type->str (),
267
+ " Expected supertype of " ,
268
+ mapped_type->str (),
269
+ " but got graph_type " ,
270
+ graph_type->str (),
271
+ " at output index " ,
272
+ i,
273
+ " of schema: " ,
274
+ *schema);
177
275
}
276
+ }
178
277
179
- Function& shape_compute_function =
180
- module.get_function (shape_compute_function_name);
181
- std::shared_ptr<Graph> graph =
182
- toGraphFunction (shape_compute_function).graph ();
278
+ void transformShapeFunction (
279
+ const FunctionSchema* schema_string,
280
+ std::shared_ptr<Graph> graph) {
183
281
Inline (*graph);
184
282
185
283
// ATEN operators can return multiple unboxed values, this in contrast to
@@ -197,9 +295,33 @@ void registerSchema(
197
295
graph->registerOutput (v);
198
296
}
199
297
}
200
- // allow extra unused arguments to map multiple functions to e.g. unary
201
- TORCH_INTERNAL_ASSERT (
202
- graph->inputs ().size () <= schema_string->arguments ().size ());
298
+ }
299
+
300
+ void registerSchema (
301
+ const FunctionSchema* schema_string,
302
+ const std::string& shape_compute_function_name,
303
+ std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
304
+ const CompilationUnit& module) {
305
+ if (reused_functions.count (shape_compute_function_name)) {
306
+ auto graph = reused_functions[shape_compute_function_name];
307
+
308
+ // allow extra unused arguments to map multiple functions to e.g. unary
309
+ TORCH_INTERNAL_ASSERT (
310
+ graph->inputs ().size () <= schema_string->arguments ().size ());
311
+
312
+ cached_schema_to_graph[schema_string] = graph;
313
+ return ;
314
+ }
315
+
316
+ Function& shape_compute_function =
317
+ module.get_function (shape_compute_function_name);
318
+ std::shared_ptr<Graph> graph =
319
+ toGraphFunction (shape_compute_function).graph ();
320
+
321
+ transformShapeFunction (schema_string, graph);
322
+ // NB: we lint the shape functions registered in source
323
+ // in a test file
324
+ // LintShapeComputeGraph(schema_string, graph);
203
325
204
326
cached_schema_to_graph[schema_string] = graph;
205
327
reused_functions[shape_compute_function_name] = graph;
@@ -299,8 +421,34 @@ void RegisterShapeComputeGraphForSchema(
299
421
if (cached_schema_to_graph.size () == 0 ) {
300
422
loadFunctions ();
301
423
}
424
+ transformShapeFunction (&schema, g);
425
+ LintShapeComputeGraph (&schema, g);
426
+
302
427
cached_schema_to_graph[&schema] = g;
303
428
}
304
429
430
+ std::vector<const FunctionSchema*> RegisteredShapeComputeSchemas () {
431
+ std::lock_guard<std::mutex> guard (lock);
432
+ if (cached_schema_to_graph.size () == 0 ) {
433
+ loadFunctions ();
434
+ }
435
+
436
+ std::vector<const FunctionSchema*> schemas;
437
+ schemas.reserve (cached_schema_to_graph.size ());
438
+ for (const auto & pair : cached_schema_to_graph) {
439
+ schemas.push_back (pair.first );
440
+ }
441
+ return schemas;
442
+ }
443
+
444
+ void LintShapeComputeGraph (
445
+ const FunctionSchema* schema,
446
+ const std::shared_ptr<Graph>& graph) {
447
+ checkInputAndOutputTypes (schema, graph);
448
+ checkForWhileLoop (schema, graph);
449
+ checkInputReturnedAsOutput (schema, graph);
450
+ // TODO: other checks ? list ops which we don't symbolically optimize, etc ?
451
+ }
452
+
305
453
} // namespace jit
306
454
} // namespace torch
0 commit comments