+ #include <c10/util/Exception.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
+ #include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/inliner.h>
+ #include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
@@ -160,26 +163,130 @@ const at::optional<const FunctionSchema*> getInplaceVariant(
  return at::nullopt;
}

- void registerSchema(
-     const FunctionSchema* schema_string,
-     const std::string& shape_compute_function_name,
-     std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
-     const CompilationUnit& module) {
-   if (reused_functions.count(shape_compute_function_name)) {
-     auto graph = reused_functions[shape_compute_function_name];
+ TypePtr mapTensorToListOfInts(TypePtr type) {
+   if (type->cast<TensorType>()) {
+     return ListType::ofInts();
+   }
+   at::ArrayRef<TypePtr> contained = type->containedTypes();
+   if (contained.empty()) {
+     return type;
+   }
+   return type->withContained(
+       fmap(type->containedTypes(), mapTensorToListOfInts));
+ }

- // allow extra unused arguments to map multiple functions to e.g. unary
+ void checkForWhileLoop(
+     const FunctionSchema* schema,
+     std::shared_ptr<Graph> graph) {
+   DepthFirstGraphNodeIterator graph_it(graph);
+   for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) {
+     if (node->kind() != prim::Loop) {
+       continue;
+     }
+     LoopView loop(node);
+     if (loop.loopType() != LoopView::For) {
+       TORCH_WARN(
+           "While loops are not yet implemented in unrolling which may make this shape function difficult to partially evaluate: ",
+           *node,
+           " for schema ",
+           *schema);
+     }
+   }
+ }
+
+ void checkInputReturnedAsOutput(
+     const FunctionSchema* schema,
+     const std::shared_ptr<Graph>& graph) {
+   // Could use alias db here as well but would have to warn because it's
+   // imprecise
+   for (size_t i : c10::irange(graph->inputs().size())) {
+     Value* input = graph->inputs().at(i);
+     for (size_t j : c10::irange(graph->outputs().size())) {
+       Value* output = graph->outputs().at(j);
+       TORCH_CHECK(
+           input != output,
+           "For schema: ",
+           *schema,
+           " input index ",
+           i,
+           " is returned as output index ",
+           j,
+           ". Shape functions must return new unaliased lists");
+     }
+   }
+ }
+
+ void checkInputAndOutputTypes(
+     const FunctionSchema* schema,
+     const std::shared_ptr<Graph>& graph) {
+   // allow extra unused arguments to map multiple functions to e.g. unary
+   TORCH_CHECK(
+       graph->inputs().size() <= schema->arguments().size(),
+       "Shape function must have fewer arguments than schema. Got ",
+       graph->inputs().size(),
+       " graph arguments and ",
+       schema->arguments().size(),
+       " schema arguments of schema: ",
+       *schema);
+
+   for (auto i : c10::irange(graph->inputs().size())) {
+     auto inp_type = schema->arguments().at(i).type();
+     auto mapped_type = mapTensorToListOfInts(inp_type);
+     auto graph_type = graph->inputs().at(i)->type();
    TORCH_INTERNAL_ASSERT(
-         graph->inputs().size() <= schema_string->arguments().size());
+         mapped_type->isSubtypeOf(graph->inputs().at(i)->type()),
+         "For schema type: ",
+         inp_type->str(),
+         " Expected supertype of ",
+         mapped_type->str(),
+         " but got graph_type ",
+         graph_type->str(),
+         " at index ",
+         i,
+         " of schema: ",
+         *schema);
+   }

-     cached_schema_to_graph[schema_string] = graph;
-     return;
+   TORCH_CHECK(
+       graph->outputs().size() == schema->returns().size(),
+       "Shape function must have the same number of outputs as its schema. Got ",
+       graph->outputs().size(),
+       " graph outputs and ",
+       schema->returns().size(),
+       " schema returns of schema: ",
+       *schema);
+
+   for (auto i : c10::irange(schema->returns().size())) {
+     auto out_type = schema->returns().at(i).type();
+     auto mapped_type = mapTensorToListOfInts(out_type);
+     auto graph_type = graph->outputs().at(i)->type();
+     TORCH_INTERNAL_ASSERT(
+         mapped_type->isSubtypeOf(graph->outputs().at(i)->type()),
+         "For schema type: ",
+         out_type->str(),
+         " Expected supertype of ",
+         mapped_type->str(),
+         " but got graph_type ",
+         graph_type->str(),
+         " at output index ",
+         i,
+         " of schema: ",
+         *schema);
  }
+ }

-   Function& shape_compute_function =
-       module.get_function(shape_compute_function_name);
-   std::shared_ptr<Graph> graph =
-       toGraphFunction(shape_compute_function).graph();
+ void checkShapeFunction(
+     const FunctionSchema* schema,
+     const std::shared_ptr<Graph>& graph) {
+   checkInputAndOutputTypes(schema, graph);
+   checkForWhileLoop(schema, graph);
+   checkInputReturnedAsOutput(schema, graph);
+   // TODO: other checks? list ops which we don't symbolically optimize, etc?
+ }
+
+ void transformShapeFunction(
+     const FunctionSchema* schema_string,
+     std::shared_ptr<Graph> graph) {
  Inline(*graph);

  // ATEN operators can return multiple unboxed values, this in contrast to
@@ -197,9 +304,31 @@ void registerSchema(
      graph->registerOutput(v);
    }
  }
-   // allow extra unused arguments to map multiple functions to e.g. unary
-   TORCH_INTERNAL_ASSERT(
-       graph->inputs().size() <= schema_string->arguments().size());
+ }
+
+ void registerSchema(
+     const FunctionSchema* schema_string,
+     const std::string& shape_compute_function_name,
+     std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
+     const CompilationUnit& module) {
+   if (reused_functions.count(shape_compute_function_name)) {
+     auto graph = reused_functions[shape_compute_function_name];
+
+     // allow extra unused arguments to map multiple functions to e.g. unary
+     TORCH_INTERNAL_ASSERT(
+         graph->inputs().size() <= schema_string->arguments().size());
+
+     cached_schema_to_graph[schema_string] = graph;
+     return;
+   }
+
+   Function& shape_compute_function =
+       module.get_function(shape_compute_function_name);
+   std::shared_ptr<Graph> graph =
+       toGraphFunction(shape_compute_function).graph();
+
+   transformShapeFunction(schema_string, graph);
+   checkShapeFunction(schema_string, graph);

  cached_schema_to_graph[schema_string] = graph;
  reused_functions[shape_compute_function_name] = graph;
@@ -300,6 +429,9 @@ void RegisterShapeComputeGraphForSchema(
  if (cached_schema_to_graph.size() == 0) {
    loadFunctions();
  }
+   transformShapeFunction(&schema, g);
+   checkShapeFunction(&schema, g);
+
  cached_schema_to_graph[&schema] = g;
}
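Note: the checks added above operate on shape compute graphs, which are TorchScript functions over `List[int]` shapes rather than Tensors (this is what `mapTensorToListOfInts` encodes when comparing graph types against the operator schema). Below is a minimal, hypothetical sketch of registering a custom shape compute graph through `RegisterShapeComputeGraphForSchema`, so that `transformShapeFunction` and `checkShapeFunction` from this diff run on it. The shape-function source, the `aten::relu` operator literal, and the helper `registerExampleShapeFunction` are illustrative assumptions, not part of the diff.

```cpp
// Hypothetical usage sketch (assumed names; not part of the diff above).
#include <memory>
#include <string>

#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/jit.h>

void registerExampleShapeFunction() {
  using namespace torch::jit;

  // Shape functions are TorchScript over List[int] shapes, not Tensors.
  // Building a fresh output list (instead of returning `self`) satisfies
  // checkInputReturnedAsOutput, and the for-loop keeps checkForWhileLoop quiet.
  auto cu = compile(R"JIT(
def unary_shape(self: List[int]) -> List[int]:
  out: List[int] = []
  for elem in self:
    out.append(elem)
  return out
)JIT");
  const std::string fn_name = "unary_shape";
  std::shared_ptr<Graph> graph =
      toGraphFunction(cu->get_function(fn_name)).graph();

  // The registry keys graphs by schema pointer, so take the schema from the
  // operator registry rather than from a temporary FunctionSchema.
  auto op = getOperatorForLiteral("aten::relu(Tensor self) -> Tensor");
  // transformShapeFunction and checkShapeFunction run inside this call.
  RegisterShapeComputeGraphForSchema(op->schema(), graph);
}
```

With these checks in place, a shape function that returned its input list directly (for example `return self`) would fail `checkInputReturnedAsOutput`, and one written with a `while` loop would draw the `TORCH_WARN` from `checkForWhileLoop`.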