
Commit f41c0e1

Author: Elias Ellison (committed)

Add comments for adding shape function and linting

ghstack-source-id: c7a5def
Pull Request resolved: #73570

1 parent cfd92f2   commit f41c0e1

6 files changed: +238 -19 lines changed

test/jit/test_symbolic_shape_analysis.py

Lines changed: 35 additions & 0 deletions

@@ -12,6 +12,7 @@
 )
 from torch.testing._internal.common_utils import make_tensor
 from torch.testing._internal.jit_utils import JitTestCase, execWrapper
+from typing import List, Any

 if __name__ == '__main__':
     raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
@@ -498,3 +499,37 @@ def test_shape_function_includes(self):
         m2_shape = [20, 10]
         res = torch.jit._shapes.matmul(m1_shape, m2_shape)
         self.assertEqual(res, [10, 10])
+
+    def test_register_function_error_checking(self):
+        # this will error before registering on global map, so
+        # no issue in overwriting schema mappings
+        @torch.jit.script
+        def foo(x, y):
+            return x + y
+
+        node = foo.graph.findNode("aten::add")
+
+        @torch.jit.script
+        def wrong_input_types(x, y):
+            x: List[int] = []
+            return x
+        with self.assertRaisesRegex(RuntimeError, "Expected supertype of int"):
+            torch._C._jit_register_shape_compute_graph_for_node(node, wrong_input_types.graph)
+
+        @torch.jit.script
+        def wrong_output_types(x: List[int], y: List[int]):
+            x: List[Tensor] = []
+            return x
+
+        with self.assertRaisesRegex(RuntimeError, "but got graph_type"):
+            torch._C._jit_register_shape_compute_graph_for_node(node, wrong_output_types.graph)
+
+        @torch.jit.script
+        def too_many_inputs(x: List[int], y: List[int], z: Any, z2: Any):
+            x: List[int] = []
+            return x
+
+        with self.assertRaises(RuntimeError) as error:
+            torch._C._jit_register_shape_compute_graph_for_node(node, too_many_inputs.graph)
+
+        self.assertTrue("fewer arguments than schema" in str(error.exception))

torch/csrc/jit/python/init.cpp

Lines changed: 13 additions & 0 deletions

@@ -159,6 +159,19 @@ void initJITBindings(PyObject* module) {
             }
             return shapeComputeGraphForSchema(n->schema());
           })
+      // using Node* here instead of Schema because looking up the schema
+      // and passing it in from Python will have a different pointer than the
+      // schema that is globally used for caching
+      .def(
+          "_jit_register_shape_compute_graph_for_node",
+          [](Node* n, std::shared_ptr<Graph>& graph) {
+            if (n->maybeSchema()) {
+              const FunctionSchema& schema = n->schema();
+              RegisterShapeComputeGraphForSchema(schema, graph);
+            } else {
+              TORCH_INTERNAL_ASSERT(false, "Expected schema", n);
+            }
+          })
       .def("_jit_pass_propagate_shapes_on_graph", PropagateShapesOnGraph)
       .def(
           "_jit_pass_propagate_shapes_on_graph_and_build_compute",

torch/csrc/jit/runtime/shape_functions.h

Lines changed: 1 addition & 1 deletion

@@ -343,7 +343,7 @@ def conv2d(

 def batch_norm(
     input: List[int],
-    weight: List[int],
+    weight: Optional[List[int]],
     bias: Optional[List[int]],
     running_mean: Optional[List[int]],
     running_var: Optional[List[int]],
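
For context, the body of this shape function is not part of the hunk; the following is a hedged sketch of how such a shape-preserving function could look, illustrating the convention rather than reproducing the exact body in `shape_functions.h`.

```python
from typing import List, Optional

def batch_norm(
    input: List[int],
    weight: Optional[List[int]],
    bias: Optional[List[int]],
    running_mean: Optional[List[int]],
    running_var: Optional[List[int]],
    training: bool,
    momentum: float,
    eps: float,
    cudnn_enabled: bool,
) -> List[int]:
    # batch_norm is shape-preserving: the optional weight/bias/running stats do
    # not affect the output shape, so only `input` is inspected. Making weight
    # Optional[List[int]] mirrors the "Tensor? weight" argument in the schema.
    out: List[int] = []
    for dim in input:
        out.append(dim)
    return out
```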

torch/csrc/jit/runtime/symbolic_shape_registry.cpp

Lines changed: 150 additions & 18 deletions

@@ -1,6 +1,9 @@
+#include <c10/util/Exception.h>
 #include <torch/csrc/jit/frontend/ir_emitter.h>
+#include <torch/csrc/jit/ir/ir_views.h>
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/inliner.h>
+#include <torch/csrc/jit/runtime/graph_iterator.h>
 #include <torch/csrc/jit/runtime/operator.h>
 #include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
 #include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
@@ -160,26 +163,130 @@ const at::optional<const FunctionSchema*> getInplaceVariant(
   return at::nullopt;
 }

-void registerSchema(
-    const FunctionSchema* schema_string,
-    const std::string& shape_compute_function_name,
-    std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
-    const CompilationUnit& module) {
-  if (reused_functions.count(shape_compute_function_name)) {
-    auto graph = reused_functions[shape_compute_function_name];
+TypePtr mapTensorToListOfInts(TypePtr type) {
+  if (type->cast<TensorType>()) {
+    return ListType::ofInts();
+  }
+  at::ArrayRef<TypePtr> contained = type->containedTypes();
+  if (contained.empty()) {
+    return type;
+  }
+  return type->withContained(
+      fmap(type->containedTypes(), mapTensorToListOfInts));
+}

-  // allow extra unused arguments to map multiple functions to e.g. unary
+void checkForWhileLoop(
+    const FunctionSchema* schema,
+    std::shared_ptr<Graph> graph) {
+  DepthFirstGraphNodeIterator graph_it(graph);
+  for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) {
+    if (node->kind() != prim::Loop) {
+      continue;
+    }
+    LoopView loop(node);
+    if (loop.loopType() != LoopView::For) {
+      TORCH_WARN(
+          "While loops are not yet implemented in unrolling which may make this shape function difficult to partially evaluate: ",
+          *node,
+          " for schema ",
+          *schema);
+    }
+  }
+}
+
+void checkInputReturnedAsOutput(
+    const FunctionSchema* schema,
+    const std::shared_ptr<Graph>& graph) {
+  // Could use alias db here as well but would have to warn because it's
+  // imprecise
+  for (size_t i : c10::irange(graph->inputs().size())) {
+    Value* input = graph->inputs().at(i);
+    for (size_t j : c10::irange(graph->outputs().size())) {
+      Value* output = graph->outputs().at(j);
+      TORCH_CHECK(
+          input != output,
+          "For schema: ",
+          *schema,
+          " input index ",
+          i,
+          " is returned as output index ",
+          j,
+          ". Shape functions must return new unaliased lists");
+    }
+  }
+}
+
+void checkInputAndOutputTypes(
+    const FunctionSchema* schema,
+    const std::shared_ptr<Graph>& graph) {
+  // allow extra unused arguments to map multiple functions to e.g. unary
+  TORCH_CHECK(
+      graph->inputs().size() <= schema->arguments().size(),
+      "Shape function must have fewer arguments than schema. Got ",
+      graph->inputs().size(),
+      " graph arguments and ",
+      schema->arguments().size(),
+      " schema arguments of schema: ",
+      *schema);
+
+  for (auto i : c10::irange(graph->inputs().size())) {
+    auto inp_type = schema->arguments().at(i).type();
+    auto mapped_type = mapTensorToListOfInts(inp_type);
+    auto graph_type = graph->inputs().at(i)->type();
     TORCH_INTERNAL_ASSERT(
-        graph->inputs().size() <= schema_string->arguments().size());
+        mapped_type->isSubtypeOf(graph->inputs().at(i)->type()),
+        "For schema type: ",
+        inp_type->str(),
+        " Expected supertype of ",
+        mapped_type->str(),
+        " but got graph_type ",
+        graph_type->str(),
+        " at index ",
+        i,
+        " of schema: ",
+        *schema);
+  }

-    cached_schema_to_graph[schema_string] = graph;
-    return;
+  TORCH_CHECK(
+      graph->outputs().size() == schema->returns().size(),
+      "Shape function equal number of outputs as schema. Got ",
+      graph->outputs().size(),
+      " graph outputs and ",
+      schema->returns().size(),
+      " schema returns of schema: ",
+      *schema);
+
+  for (auto i : c10::irange(schema->returns().size())) {
+    auto out_type = schema->returns().at(i).type();
+    auto mapped_type = mapTensorToListOfInts(out_type);
+    auto graph_type = graph->outputs().at(i)->type();
+    TORCH_INTERNAL_ASSERT(
+        mapped_type->isSubtypeOf(graph->outputs().at(i)->type()),
+        "For schema type: ",
+        out_type->str(),
+        " Expected supertype of ",
+        mapped_type->str(),
+        " but got graph_type ",
+        graph_type->str(),
+        " at output index ",
+        i,
+        " of schema: ",
+        *schema);
   }
+}

-  Function& shape_compute_function =
-      module.get_function(shape_compute_function_name);
-  std::shared_ptr<Graph> graph =
-      toGraphFunction(shape_compute_function).graph();
+void checkShapeFunction(
+    const FunctionSchema* schema,
+    const std::shared_ptr<Graph>& graph) {
+  checkInputAndOutputTypes(schema, graph);
+  checkForWhileLoop(schema, graph);
+  checkInputReturnedAsOutput(schema, graph);
+  // TODO: other checks ? list ops which we don't symbolically optimize, etc ?
+}
+
+void transformShapeFunction(
+    const FunctionSchema* schema_string,
+    std::shared_ptr<Graph> graph) {
   Inline(*graph);

   // ATEN operators can return multiple unboxed values, this in contrast to
@@ -197,9 +304,31 @@ void registerSchema(
       graph->registerOutput(v);
     }
   }
-  // allow extra unused arguments to map multiple functions to e.g. unary
-  TORCH_INTERNAL_ASSERT(
-      graph->inputs().size() <= schema_string->arguments().size());
+}
+
+void registerSchema(
+    const FunctionSchema* schema_string,
+    const std::string& shape_compute_function_name,
+    std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
+    const CompilationUnit& module) {
+  if (reused_functions.count(shape_compute_function_name)) {
+    auto graph = reused_functions[shape_compute_function_name];
+
+    // allow extra unused arguments to map multiple functions to e.g. unary
+    TORCH_INTERNAL_ASSERT(
+        graph->inputs().size() <= schema_string->arguments().size());
+
+    cached_schema_to_graph[schema_string] = graph;
+    return;
+  }
+
+  Function& shape_compute_function =
+      module.get_function(shape_compute_function_name);
+  std::shared_ptr<Graph> graph =
+      toGraphFunction(shape_compute_function).graph();
+
+  transformShapeFunction(schema_string, graph);
+  checkShapeFunction(schema_string, graph);

   cached_schema_to_graph[schema_string] = graph;
   reused_functions[shape_compute_function_name] = graph;
@@ -300,6 +429,9 @@ void RegisterShapeComputeGraphForSchema(
   if (cached_schema_to_graph.size() == 0) {
     loadFunctions();
   }
+  transformShapeFunction(&schema, g);
+  checkShapeFunction(&schema, g);
+
   cached_schema_to_graph[&schema] = g;
 }
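
To make the new lints concrete, here is a hedged Python-side sketch (the name `aliasing_shape` is illustrative, not from the commit) of a registration that checkInputReturnedAsOutput should reject, because the graph output aliases a graph input.

```python
import torch
from typing import List

@torch.jit.script
def foo(x, y):
    return x + y

node = foo.graph.findNode("aten::add")

# Returns its input list directly, so the graph output is the same Value as a
# graph input -- shape functions must instead return new, unaliased lists.
@torch.jit.script
def aliasing_shape(self: List[int], other: List[int]) -> List[int]:
    return self

try:
    torch._C._jit_register_shape_compute_graph_for_node(node, aliasing_shape.graph)
except RuntimeError as e:
    # Expected to mention "Shape functions must return new unaliased lists".
    print(e)
```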

torch/csrc/jit/runtime/symbolic_shape_registry.h

Lines changed: 38 additions & 0 deletions

@@ -8,6 +8,44 @@
 namespace torch {
 namespace jit {

+/*
+ADDING A NEW SHAPE GRAPH:
+- For one node schema, there is one corresponding registered shape compute
+graph. The schema of the graph should be the same except for Tensor arguments.
+For every Tensor input in the operator schema, there should be a List[int]
+corresponding to that Tensor's shape. For example:
+  "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor"
+  ==> def linear(input: List[int], weight: List[int], bias: Optional[List[int]])
+
+Additionally, arguments which are unused at the end of the schema may be left
+off. This allows sharing a single graph for multiple function schemas, such as
+unary operators with different trailing arguments that do not affect the output
+shape.
+
+The shape graph should return a new, unaliased List[int] (or a tuple of lists
+for multiple returns) and should not modify any input lists. This allows the
+shape graphs to be composed and executed.
+
+The shape analysis (particularly for non-complete, or symbolic, shapes) works by
+partially evaluating the JIT IR. It may be possible for a Graph to be registered
+that we cannot currently partially evaluate. If this happens, please file an
+issue. There are lints registered to avoid particular known patterns (continue
+or break or early return in a loop). Those may be improved in the future;
+please file an issue if necessary.
+
+To debug (and write initially), the recommended flow is to define these
+functions in Python and iterate there. Functions in `shape_functions.h` and
+`shape_functions_1.h` should be executable in Python.
+
+To test operators, the preferred flow is through OpInfos, with
+`assert_jit_shape_analysis=True`. If this is not feasible, you can look at tests
+in `test_symbolic_shape_analysis.py` such as `test_adaptive_avg_pool2d`.
+
+Operators which take in a list of tensors, such as concat, are not yet
+supported. Concat has been special cased and could be generalized as needed.
+Please file an issue.
+*/
+
 TORCH_API void RegisterShapeComputeGraphForSchema(
     const FunctionSchema& schema,
     std::shared_ptr<Graph> g);
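
Following the header comment's aten::linear example, here is a hedged sketch of what a matching shape compute function might look like in Python. It is simplified (assumes a 2-D weight and omits bias broadcast checks) and is not presented as the exact function PyTorch registers.

```python
from typing import List, Optional

def linear(
    input: List[int], weight: List[int], bias: Optional[List[int]]
) -> List[int]:
    # weight is [out_features, in_features]; the output shape is input's shape
    # with the last dimension replaced by out_features. bias is unused here.
    assert len(weight) == 2
    assert input[-1] == weight[1]
    out: List[int] = []
    for i in range(len(input) - 1):
        out.append(input[i])
    out.append(weight[0])
    return out
```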

torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp

Lines changed: 1 addition & 0 deletions

@@ -118,6 +118,7 @@ const OperatorMap<std::string>& get_tensorexpr_elementwise_set() {
       {"aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", "broadcast_one_three"},
       // TODO: enable slice, shape inference is not implemented for this op yet
   };
+  // clang-format on
   return tensorexpr_elementwise_set;
 }
