Skip to content

Commit f7c94dd

Browse files
author
Elias Ellison
committed
Add comments for adding shape function and linting
ghstack-source-id: 71302d6 Pull Request resolved: #73570
1 parent 92f01e1 commit f7c94dd

7 files changed

+270
-20
lines changed

test/cpp/jit/test_misc.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2973,6 +2973,15 @@ TEST(TestFunctionExecutor, RunDecompositionTest) {
29732973
}
29742974
}
29752975

2976+
TEST(TestShapeGraphLinting, Basic) {
2977+
auto schemas = RegisteredShapeComputeSchemas();
2978+
for (const auto& schema : schemas) {
2979+
auto g = shapeComputeGraphForSchema(*schema);
2980+
TORCH_INTERNAL_ASSERT(g);
2981+
LintShapeComputeGraph(schema, *g);
2982+
}
2983+
}
2984+
29762985
// TODO: move to test_kernel when global settings are explicit
29772986
// fusion parameters
29782987
class Composed : public ::testing::Test {

test/jit/test_symbolic_shape_analysis.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
)
1313
from torch.testing._internal.common_utils import make_tensor
1414
from torch.testing._internal.jit_utils import JitTestCase, execWrapper
15+
from typing import List, Any
1516

1617
if __name__ == '__main__':
1718
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
@@ -498,3 +499,37 @@ def test_shape_function_includes(self):
498499
m2_shape = [20, 10]
499500
res = torch.jit._shapes.matmul(m1_shape, m2_shape)
500501
self.assertEqual(res, [10, 10])
502+
503+
def test_register_function_error_checking(self):
504+
# this will error before registering on global map, so
505+
# no issue in overwriting schema mappings
506+
@torch.jit.script
507+
def foo(x, y):
508+
return x + y
509+
510+
node = foo.graph.findNode("aten::add")
511+
512+
@torch.jit.script
513+
def wrong_input_types(x, y):
514+
x: List[int] = []
515+
return x
516+
with self.assertRaisesRegex(RuntimeError, "Expected supertype of int"):
517+
torch._C._jit_register_shape_compute_graph_for_node(node, wrong_input_types.graph)
518+
519+
@torch.jit.script
520+
def wrong_output_types(x: List[int], y: List[int]):
521+
x: List[Tensor] = []
522+
return x
523+
524+
with self.assertRaisesRegex(RuntimeError, "but got graph_type"):
525+
torch._C._jit_register_shape_compute_graph_for_node(node, wrong_output_types.graph)
526+
527+
@torch.jit.script
528+
def too_many_inputs(x: List[int], y: List[int], z: Any, z2: Any):
529+
x: List[int] = []
530+
return x
531+
532+
with self.assertRaises(RuntimeError) as error:
533+
torch._C._jit_register_shape_compute_graph_for_node(node, too_many_inputs.graph)
534+
535+
self.assertTrue("fewer arguments than schema" in str(error.exception))

torch/csrc/jit/python/init.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,19 @@ void initJITBindings(PyObject* module) {
170170
return DecompositionGraphForSchema(n->schema());
171171
})
172172
.def("_jit_pass_run_decompositions", RunDecompositions)
173+
// using Node* here instead of Schema because looking up the schema
174+
// and passing it in from Python will have a different pointer than the
175+
// schema that is globally used for caching
176+
.def(
177+
"_jit_register_shape_compute_graph_for_node",
178+
[](Node* n, std::shared_ptr<Graph>& graph) {
179+
if (n->maybeSchema()) {
180+
const FunctionSchema& schema = n->schema();
181+
RegisterShapeComputeGraphForSchema(schema, graph);
182+
} else {
183+
TORCH_INTERNAL_ASSERT(false, "Expected schema", n);
184+
}
185+
})
173186
.def("_jit_pass_propagate_shapes_on_graph", PropagateShapesOnGraph)
174187
.def(
175188
"_jit_pass_propagate_shapes_on_graph_and_build_compute",

torch/csrc/jit/runtime/shape_functions.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def conv2d(
343343
344344
def batch_norm(
345345
input: List[int],
346-
weight: List[int],
346+
weight: Optional[List[int]],
347347
bias: Optional[List[int]],
348348
running_mean: Optional[List[int]],
349349
running_var: Optional[List[int]],

torch/csrc/jit/runtime/symbolic_shape_registry.cpp

Lines changed: 166 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
#include <c10/util/Exception.h>
12
#include <torch/csrc/jit/frontend/ir_emitter.h>
3+
#include <torch/csrc/jit/ir/ir_views.h>
24
#include <torch/csrc/jit/jit_log.h>
35
#include <torch/csrc/jit/passes/inliner.h>
6+
#include <torch/csrc/jit/runtime/graph_iterator.h>
47
#include <torch/csrc/jit/runtime/operator.h>
58
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
69
#include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
@@ -160,26 +163,121 @@ const at::optional<const FunctionSchema*> getInplaceVariant(
160163
return at::nullopt;
161164
}
162165

163-
void registerSchema(
164-
const FunctionSchema* schema_string,
165-
const std::string& shape_compute_function_name,
166-
std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
167-
const CompilationUnit& module) {
168-
if (reused_functions.count(shape_compute_function_name)) {
169-
auto graph = reused_functions[shape_compute_function_name];
166+
TypePtr mapTensorToListOfInts(TypePtr type) {
167+
if (type->cast<TensorType>()) {
168+
return ListType::ofInts();
169+
}
170+
at::ArrayRef<TypePtr> contained = type->containedTypes();
171+
if (contained.empty()) {
172+
return type;
173+
}
174+
return type->withContained(
175+
fmap(type->containedTypes(), mapTensorToListOfInts));
176+
}
170177

171-
// allow extra unused arguments to map multiple functions to e.g. unary
178+
void checkForWhileLoop(
179+
const FunctionSchema* schema,
180+
std::shared_ptr<Graph> graph) {
181+
DepthFirstGraphNodeIterator graph_it(graph);
182+
for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) {
183+
if (node->kind() != prim::Loop) {
184+
continue;
185+
}
186+
LoopView loop(node);
187+
if (loop.loopType() != LoopView::For) {
188+
TORCH_WARN(
189+
"While loops are not yet implemented in unrolling which may make this shape function difficult to partially evaluate: ",
190+
*node,
191+
" for schema ",
192+
*schema);
193+
}
194+
}
195+
}
196+
197+
void checkInputReturnedAsOutput(
198+
const FunctionSchema* schema,
199+
const std::shared_ptr<Graph>& graph) {
200+
// Could use alias db here as well but would have to warn because it's
201+
// imprecise
202+
for (size_t i : c10::irange(graph->inputs().size())) {
203+
Value* input = graph->inputs().at(i);
204+
for (size_t j : c10::irange(graph->outputs().size())) {
205+
Value* output = graph->outputs().at(j);
206+
TORCH_CHECK(
207+
input != output,
208+
"For schema: ",
209+
*schema,
210+
" input index ",
211+
i,
212+
" is returned as output index ",
213+
j,
214+
". Shape functions must return new unaliased lists");
215+
}
216+
}
217+
}
218+
219+
void checkInputAndOutputTypes(
220+
const FunctionSchema* schema,
221+
const std::shared_ptr<Graph>& graph) {
222+
// allow extra unused arguments to map multiple functions to e.g. unary
223+
TORCH_CHECK(
224+
graph->inputs().size() <= schema->arguments().size(),
225+
"Shape function must have fewer arguments than schema. Got ",
226+
graph->inputs().size(),
227+
" graph arguments and ",
228+
schema->arguments().size(),
229+
" schema arguments of schema: ",
230+
*schema);
231+
232+
for (auto i : c10::irange(graph->inputs().size())) {
233+
auto inp_type = schema->arguments().at(i).type();
234+
auto mapped_type = mapTensorToListOfInts(inp_type);
235+
auto graph_type = graph->inputs().at(i)->type();
172236
TORCH_INTERNAL_ASSERT(
173-
graph->inputs().size() <= schema_string->arguments().size());
237+
mapped_type->isSubtypeOf(graph->inputs().at(i)->type()),
238+
"For schema type: ",
239+
inp_type->str(),
240+
" Expected supertype of ",
241+
mapped_type->str(),
242+
" but got graph_type ",
243+
graph_type->str(),
244+
" at index ",
245+
i,
246+
" of schema: ",
247+
*schema);
248+
}
174249

175-
cached_schema_to_graph[schema_string] = graph;
176-
return;
250+
TORCH_CHECK(
251+
graph->outputs().size() == schema->returns().size(),
252+
"Shape function equal number of outputs as schema. Got ",
253+
graph->outputs().size(),
254+
" graph outputs and ",
255+
schema->returns().size(),
256+
" schema returns of schema: ",
257+
*schema);
258+
259+
for (auto i : c10::irange(schema->returns().size())) {
260+
auto out_type = schema->returns().at(i).type();
261+
auto mapped_type = mapTensorToListOfInts(out_type);
262+
auto graph_type = graph->outputs().at(i)->type();
263+
TORCH_INTERNAL_ASSERT(
264+
mapped_type->isSubtypeOf(graph->outputs().at(i)->type()),
265+
"For schema type: ",
266+
out_type->str(),
267+
" Expected supertype of ",
268+
mapped_type->str(),
269+
" but got graph_type ",
270+
graph_type->str(),
271+
" at output index ",
272+
i,
273+
" of schema: ",
274+
*schema);
177275
}
276+
}
178277

179-
Function& shape_compute_function =
180-
module.get_function(shape_compute_function_name);
181-
std::shared_ptr<Graph> graph =
182-
toGraphFunction(shape_compute_function).graph();
278+
void transformShapeFunction(
279+
const FunctionSchema* schema_string,
280+
std::shared_ptr<Graph> graph) {
183281
Inline(*graph);
184282

185283
// ATEN operators can return multiple unboxed values, this in contrast to
@@ -197,9 +295,33 @@ void registerSchema(
197295
graph->registerOutput(v);
198296
}
199297
}
200-
// allow extra unused arguments to map multiple functions to e.g. unary
201-
TORCH_INTERNAL_ASSERT(
202-
graph->inputs().size() <= schema_string->arguments().size());
298+
}
299+
300+
void registerSchema(
301+
const FunctionSchema* schema_string,
302+
const std::string& shape_compute_function_name,
303+
std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
304+
const CompilationUnit& module) {
305+
if (reused_functions.count(shape_compute_function_name)) {
306+
auto graph = reused_functions[shape_compute_function_name];
307+
308+
// allow extra unused arguments to map multiple functions to e.g. unary
309+
TORCH_INTERNAL_ASSERT(
310+
graph->inputs().size() <= schema_string->arguments().size());
311+
312+
cached_schema_to_graph[schema_string] = graph;
313+
return;
314+
}
315+
316+
Function& shape_compute_function =
317+
module.get_function(shape_compute_function_name);
318+
std::shared_ptr<Graph> graph =
319+
toGraphFunction(shape_compute_function).graph();
320+
321+
transformShapeFunction(schema_string, graph);
322+
// NB: we lint the shape functions registered in source
323+
// in a test file
324+
// LintShapeComputeGraph(schema_string, graph);
203325

204326
cached_schema_to_graph[schema_string] = graph;
205327
reused_functions[shape_compute_function_name] = graph;
@@ -299,8 +421,34 @@ void RegisterShapeComputeGraphForSchema(
299421
if (cached_schema_to_graph.size() == 0) {
300422
loadFunctions();
301423
}
424+
transformShapeFunction(&schema, g);
425+
LintShapeComputeGraph(&schema, g);
426+
302427
cached_schema_to_graph[&schema] = g;
303428
}
304429

430+
std::vector<const FunctionSchema*> RegisteredShapeComputeSchemas() {
431+
std::lock_guard<std::mutex> guard(lock);
432+
if (cached_schema_to_graph.size() == 0) {
433+
loadFunctions();
434+
}
435+
436+
std::vector<const FunctionSchema*> schemas;
437+
schemas.reserve(cached_schema_to_graph.size());
438+
for (const auto& pair : cached_schema_to_graph) {
439+
schemas.push_back(pair.first);
440+
}
441+
return schemas;
442+
}
443+
444+
void LintShapeComputeGraph(
445+
const FunctionSchema* schema,
446+
const std::shared_ptr<Graph>& graph) {
447+
checkInputAndOutputTypes(schema, graph);
448+
checkForWhileLoop(schema, graph);
449+
checkInputReturnedAsOutput(schema, graph);
450+
// TODO: other checks ? list ops which we don't symbolically optimize, etc ?
451+
}
452+
305453
} // namespace jit
306454
} // namespace torch

torch/csrc/jit/runtime/symbolic_shape_registry.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,56 @@
88
namespace torch {
99
namespace jit {
1010

11+
/*
12+
ADDING A NEW SHAPE GRAPH:
13+
- For one node schema, there is one corresponding registered shape compute
14+
graph. The schema of the graph should be the same except for Tensor arguments.
15+
For every Tensor input in operator schema, there should be a List[int]
16+
corresponding to that Tensor's shape. For example: "aten::linear(Tensor input,
17+
Tensor weight, Tensor? bias=None) -> Tensor" ==> def linear(input: List[int],
18+
weight: List[int], bias: Optional[List[int]])
19+
20+
Additionally, arguments which are unused at the end of the schema may be left
21+
off. This allows sharing a single graph for multiple function schemas, such as
22+
unary operators with different trailing arguments that do not affect the output
23+
shape.
24+
25+
The shape graph should return a new, unaliased List[int] (or tuple of lists for
26+
multiple returns) and should not modify any input lists. This allows the shape
27+
graphs to be composed and executed.
28+
29+
The shape analysis (particularly for non-complete, or symbolic shapes) works by
30+
partially evaluating the JIT IR. It may be possible for a Graph to be registered
31+
that we cannot currently partially evaluate. If this happens, please file an
32+
issue. There are lints registered to avoid particular known patterns (continue
33+
or break or early return in a loop). Those may be improved in the future, please
34+
file an issue if necessary.
35+
36+
To debug (and write initially) the recommended flow is to define these functions
37+
in python and iterate there. Functions in `shape_functions.h` and
38+
`shape_functions_1.h` should be executable in python.
39+
40+
To test operators, the preferred flow is through OpInfos, with
41+
`assert_jit_shape_analysis=True`. If this is not feasible, you can look at tests
42+
in `test_symbolic_shape_analysis.py` such as `test_adaptive_avg_pool2d`.
43+
44+
Operators which take in a list of tensors, such as concat, are not yet
45+
supported. Concat has been special cased and could be generalized as needed.
46+
Please file an issue.
47+
*/
48+
1149
TORCH_API void RegisterShapeComputeGraphForSchema(
1250
const FunctionSchema& schema,
1351
std::shared_ptr<Graph> g);
1452

1553
TORCH_API c10::optional<std::shared_ptr<Graph>> shapeComputeGraphForSchema(
1654
const FunctionSchema& schema);
1755

56+
TORCH_API std::vector<const FunctionSchema*> RegisteredShapeComputeSchemas();
57+
58+
TORCH_API void LintShapeComputeGraph(
59+
const FunctionSchema* schema,
60+
const std::shared_ptr<Graph>& graph);
61+
1862
} // namespace jit
1963
} // namespace torch

torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ const OperatorMap<std::string>& get_tensorexpr_elementwise_set() {
118118
{"aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", "broadcast_one_three"},
119119
// TODO: enable slice, shape inference is not implemented for this op yet
120120
};
121-
return tensorexpr_elementwise_set;
121+
// clang-format on
122+
return tensorexpr_elementwise_set;
122123
}
123124

124125
}

0 commit comments

Comments
 (0)