Skip to content

Commit 344bab2

Browse files
XuanQipytorchmergebot
authored andcommitted
[RFC]: Functionalize assertions (#103757)
The idea here is to create do a graph mutation to: * Create an initial dependency token at the beginning of the program. * Replace non-functional version of assertion statements to functional version. * The functional version of assertion statement will: * Accept a dependency token from output of previous functional assertion statement (or the initial dependency token if there isn't any). * Generate a dependency token as the output of assertion statement. * Augment the output to include the dependency token generated by last assertion statement. The goal here is to: * Form an explicit dependency chain and avoid potential reordering during other passes of compiling. * Make the assertions a part of overall execution graph will affect the final output (or it could potentially be DCEed). **NOTE:** * Currently only cover `contrain_range` and WIP to support other assertions. Send out this PR to collect feedback first. * Here it only focus on implementation itself. Will integrate it with current export in future PR. Pull Request resolved: #103757 Approved by: https://github.com/avikchaudhuri
1 parent 98d513c commit 344bab2

5 files changed

+203
-7
lines changed

test/export/test_passes.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,20 @@
33

44
import torch
55
from torch.testing._internal.common_utils import run_tests, TestCase
6+
from torch.testing import FileCheck
67
from torch._dynamo.eval_frame import is_dynamo_supported
78
from torch._export import export, dynamic_dim
8-
from torch._export.constraints import constrain_as_value
9+
from torch._export.constraints import constrain_as_value, constrain_as_size
910
from torch._export.passes import (
1011
ReplaceViewOpsWithViewCopyOpsPass,
1112
)
1213
from torch._export.passes.replace_view_ops_with_view_copy_ops_pass import (
1314
is_view_op,
1415
get_view_copy_of_view_op,
1516
)
17+
from torch._export.passes.functionalize_side_effectful_ops_pass import (
18+
_FunctionalizeSideEffectfulOpsPass,
19+
)
1620
from functorch.experimental.control_flow import cond
1721

1822

@@ -335,6 +339,81 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
335339
real_result = m(x, y)
336340
self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
337341

342+
def test_functionalize_inline_contraints(self) -> None:
343+
def f(x):
344+
a = x.item()
345+
constrain_as_size(a, 4, 7)
346+
return torch.empty((a, 4))
347+
348+
ep = torch._export.export(f, (torch.tensor([7]),))
349+
gm = ep.graph_module
350+
FileCheck().check_count(
351+
"torch.ops.aten.sym_constrain_range.default",
352+
1,
353+
exactly=True,
354+
).run(gm.code)
355+
356+
gm = ep.transform(_FunctionalizeSideEffectfulOpsPass()).graph_module
357+
358+
with self.assertRaisesRegex(
359+
RuntimeError,
360+
r"_local_scalar_dense_default is outside of inline constraint \[4, 7\]",
361+
) as cm:
362+
gm(torch.tensor([20]))
363+
364+
inp = torch.tensor([5])
365+
res, dep_token = gm(inp)
366+
self.assertEqual(res.shape, torch.Size([5, 4]))
367+
self.assertEqual(dep_token.shape, torch.Size([]))
368+
369+
FileCheck().check_count(
370+
"torch.ops.aten._functional_sym_constrain_range", 1, exactly=True
371+
).run(gm.code)
372+
FileCheck().check_count(
373+
"torch.ops.aten.sym_constrain_range.default", 0, exactly=True
374+
).run(gm.code)
375+
376+
dep_token_node = next(n for n in gm.graph.nodes if n.name == "dep_token3")
377+
constrain_node = next(
378+
n
379+
for n in gm.graph.nodes
380+
if n.target == torch.ops.aten._functional_sym_constrain_range
381+
)
382+
self.assertEqual(constrain_node.kwargs["dep_token"], dep_token_node)
383+
384+
def test_functionalize_input_constraints(self) -> None:
385+
def f(x):
386+
return x * 2
387+
388+
inp = torch.zeros(4, 8)
389+
ep = torch._export.export(
390+
f,
391+
(inp,),
392+
constraints=[
393+
dynamic_dim(inp, 0) < 10,
394+
dynamic_dim(inp, 0) >= 3,
395+
],
396+
)
397+
FileCheck().check_count(
398+
"torch.ops.aten._assert_async.msg", 3, exactly=True
399+
).run(ep.graph_module.code)
400+
401+
gm = ep.transform(_FunctionalizeSideEffectfulOpsPass()).graph_module
402+
with self.assertRaisesRegex(
403+
RuntimeError,
404+
r"Input arg0_1.shape\[0\] is outside of specified dynamic range \[3, 9\]",
405+
):
406+
gm(torch.ones(11, 8))
407+
408+
inp = torch.ones(6, 8)
409+
self.assertEqual(gm(inp)[0], f(inp))
410+
FileCheck().check_count(
411+
"torch.ops.aten._functional_assert_async.msg", 3, exactly=True
412+
).run(gm.code)
413+
FileCheck().check_count(
414+
"torch.ops.aten._assert_async.msg", 0, exactly=True
415+
).run(gm.code)
416+
338417

339418
if __name__ == '__main__':
340419
run_tests()

torch/_export/pass_base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import operator
2+
import traceback
23
import typing
34
from contextlib import nullcontext
45
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -38,6 +39,11 @@ class ExportPassBase(PassBase):
3839
transformations.
3940
"""
4041

42+
@staticmethod
43+
def _create_dummy_node_metadata():
44+
return NodeMetadata({"stack_trace": traceback.format_exc(-1)})
45+
46+
4147
class ExportTracer(PythonKeyTracer):
4248
"""
4349
Tracer used to create nodes during the retracing part of the ExportPassBase

torch/_export/passes/add_runtime_assertions_for_constraints_pass.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import torch.fx
1414
from torch.fx.experimental.symbolic_shapes import SymInt
1515
from torch._export.pass_base import ExportPassBase, ProxyValue, PassResult
16-
from torch._export.pass_infra.node_metadata import NodeMetadata
1716
from torch._subclasses.fake_tensor import FakeTensor
1817

1918

@@ -188,11 +187,6 @@ def _insert_equality_assert_inplace(
188187
assert_msg
189188
)
190189

191-
def _create_dummy_node_metadata(self):
192-
return NodeMetadata({
193-
"stack_trace": traceback.format_exc(-1)
194-
})
195-
196190
def _insert_assert_async_inplace(self, graph, operator, args, assert_msg):
197191
"""
198192
Inserts assert_async call_function nodes in the graph. This function is
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import copy
2+
from typing import Dict, Optional, Tuple, List
3+
4+
import torch
5+
from torch._export.pass_base import ExportPassBase, PassResult, Argument
6+
from torch._export.pass_infra.node_metadata import NodeMetadata
7+
from torch._export.pass_infra.proxy_value import ProxyValue
8+
from torch._ops import OpOverload
9+
10+
aten = torch.ops.aten
11+
12+
_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: Dict[OpOverload, OpOverload] = {
13+
aten.sym_constrain_range.default: aten._functional_sym_constrain_range,
14+
aten._assert_async.msg: aten._functional_assert_async.msg,
15+
}
16+
17+
18+
class _FunctionalizeSideEffectfulOpsPass(ExportPassBase):
19+
"""
20+
Functionalize ops with side effect in graph module by replacing the op with
21+
functional version of it. A new dependency token (`dep_token`) will be
22+
created and propagated through functional ops to output.
23+
For example:
24+
```
25+
def f(x):
26+
sym_constrain_range(x.shape[0], min=1, max=3)
27+
return x.add(3)
28+
```
29+
Will be transformed to:
30+
```
31+
def f(x):
32+
dep_token0 = _make_dep_token()
33+
dep_token1 = _functional_sym_constrain_range(
34+
x.shape[0], min=1, max=3, dep_token=dep_token0
35+
)
36+
37+
return x.add(3), dep_token1
38+
```
39+
"""
40+
41+
def __init__(self) -> None:
42+
super().__init__()
43+
self._dep_token: Optional[ProxyValue] = None
44+
self._next_dep_token_index: Optional[int] = None
45+
46+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
47+
# Early return if no non-functional assertions.
48+
if not any(
49+
n.target in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS
50+
for n in graph_module.graph.nodes
51+
):
52+
return PassResult(graph_module=graph_module, modified=False)
53+
54+
gm = copy.deepcopy(graph_module)
55+
self._dep_token = None
56+
self._next_dep_token_index = None
57+
return super().call(gm)
58+
59+
def call_operator(
60+
self,
61+
op: OpOverload,
62+
args: Tuple[Argument, ...],
63+
kwargs: Dict[str, Argument],
64+
meta: NodeMetadata,
65+
) -> ProxyValue:
66+
if op not in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS:
67+
return super().call_operator(op, args, kwargs, meta)
68+
69+
if self._dep_token is None:
70+
self._dep_token = super().call_operator(
71+
aten._make_dep_token,
72+
args=(),
73+
kwargs={},
74+
meta=self._create_dummy_node_metadata(),
75+
)
76+
self._dep_token.node.name = "dep_token0"
77+
self._next_dep_token_index = 1
78+
79+
self._dep_token = super().call_operator(
80+
_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS[op],
81+
args=args,
82+
kwargs={**kwargs, "dep_token": self._dep_token},
83+
meta=meta,
84+
)
85+
assert self._next_dep_token_index is not None
86+
self._dep_token.node.name = f"dep_token{self._next_dep_token_index}"
87+
self._next_dep_token_index += 1
88+
89+
return self._dep_token
90+
91+
def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
92+
assert self._dep_token is not None
93+
94+
return super().output(results=(*results, self._dep_token), meta=meta)

torch/_meta_registrations.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,11 +376,34 @@ def assert_async_meta(val, assert_msg):
376376
return
377377

378378

379+
@register_meta(aten._make_dep_token.default)
380+
def make_dep_token(
381+
*,
382+
dtype=None,
383+
layout=None,
384+
device=None,
385+
pin_memory=None,
386+
memory_format=None,
387+
):
388+
return torch.empty([], device="meta")
389+
390+
379391
@register_meta(aten.sym_constrain_range.default)
380392
def sym_constrain_range(size, min, max):
381393
constrain_range(size, min=min, max=max)
382394

383395

396+
@register_meta(aten._functional_sym_constrain_range.default)
397+
def functional_sym_constrain_range(size, min, max, dep_token):
398+
aten.sym_constrain_range(size, min=min, max=max)
399+
return dep_token
400+
401+
402+
@register_meta(aten._functional_assert_async.msg)
403+
def functional_assert_async_meta(val, assert_msg, dep_token):
404+
return dep_token
405+
406+
384407
# From aten/src/ATen/native/LinearAlgebraUtils.h
385408
def squareCheckInputs(self: Tensor, f_name: str):
386409
assert (

0 commit comments

Comments
 (0)