
Commit 81f351a

peterbell10 authored and pytorchmergebot committed
[inductor] Prevent blowup in inner_fn_str and extract_read_writes (pytorch#88933)
Currently the default `ops` handler expects strings as arguments and just formats them into a function-call template string. For complex expressions, this can lead to exponential growth in the number of terms. Say, for example, you have:

```python
def fn(a):
    for _ in range(3):
        a = ops.mul(a, a)
    return a
```

You might expect `inner_fn_str` to contain 1 load and 3 multiplies, but instead you find 8 loads and 7 multiplies:

```python
load(arg_0, i0) * load(arg_0, i0) * load(arg_0, i0) * load(arg_0, i0) * load(arg_0, i0) * load(arg_0, i0) * load(arg_0, i0) * load(arg_0, i0)
```

This type of blowup is present in the lowering for `max_pool2d_with_indices_backward`, which in pytorch/torchdynamo#1352 was reported to have caused the entire compilation to hang.

This PR fixes the issue by formatting the string as a series of assignments to variables, so for the example above we now get:

```
tmp0 = load(arg_0, i0)
tmp1 = tmp0 * tmp0
tmp2 = tmp1 * tmp1
tmp3 = tmp2 * tmp2
return tmp3
```

which corresponds to the sequence of `ops` calls made.

Pull Request resolved: pytorch#88933
Approved by: https://github.com/jansel
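To make the mechanism concrete, here is a minimal, self-contained sketch of the idea, not the actual `V.KernelFormatterHandler` from Inductor: `NaiveStringHandler` and `AssigningFormatter` are hypothetical stand-ins for the default string-formatting `ops` handler and for the new assignment-emitting wrapper.

```python
# Hedged sketch: hypothetical stand-ins, not torch._inductor classes.


class NaiveStringHandler:
    """Formats each op by splicing argument strings into a template.

    Reused arguments get duplicated, so nested expressions grow exponentially.
    """

    def load(self, name, index):
        return f"load({name}, {index})"

    def mul(self, a, b):
        return f"{a} * {b}"


class AssigningFormatter:
    """Wraps a string-producing handler and emits one assignment per op call."""

    def __init__(self, parent):
        self.parent = parent
        self.lines = []
        self._counter = 0

    def _assign(self, expr):
        # Each op result is bound to a fresh temporary and referenced by name,
        # so shared sub-expressions are written out exactly once.
        var = f"tmp{self._counter}"
        self._counter += 1
        self.lines.append(f"{var} = {expr}")
        return var

    def __getattr__(self, name):
        parent_op = getattr(self.parent, name)

        def inner(*args):
            return self._assign(parent_op(*args))

        return inner

    def getvalue(self, result):
        self.lines.append(f"return {result}")
        return "\n".join(self.lines)


def fn(ops, a):
    for _ in range(3):
        a = ops.mul(a, a)
    return a


# Naive string formatting: 8 loads and 7 multiplies for 3 ops of real work.
naive = NaiveStringHandler()
print(fn(naive, naive.load("arg_0", "i0")))

# Assignment-based formatting: output stays linear in the number of ops calls.
fmt = AssigningFormatter(NaiveStringHandler())
print(fmt.getvalue(fn(fmt, fmt.load("arg_0", "i0"))))
```

The second print produces the `tmp0 = load(arg_0, i0)` form quoted above, while the first reproduces the exponential blowup.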
1 parent c4718e9 commit 81f351a

7 files changed: +200, -154 lines


test/inductor/test_torchinductor.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -5354,7 +5354,7 @@ def fn(x1, x2):
         traced = make_fx(fn)(x1, x2)
         compiled = compile_fx_inner(traced, [x1, x2])
         assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True)
-        assert metrics.generated_cpp_vec_kernel_count == 1
+        assert metrics.generated_cpp_vec_kernel_count == 0

         torch._dynamo.reset()
         metrics.reset()
```

test/inductor/test_torchinductor_opinfo.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -363,8 +363,6 @@ def process(device_type):
     "nn.functional.local_response_norm": {f16},
     "outer": {f16},
     "quantile": {f32, f64},
-    "scatter_reduce.amax": {f16, f32, f64},
-    "scatter_reduce.amin": {f16, f32, f64},
     "tanh": {f16},
 }
```

torch/_inductor/codegen/common.py

Lines changed: 13 additions & 110 deletions
```diff
@@ -2,19 +2,24 @@
 import contextlib
 import itertools
 import logging
-import math
 import re
-import textwrap
 import typing
 from collections import namedtuple
-from io import StringIO
 from itertools import chain

 import sympy
 from sympy.printing.printer import Printer

 from .. import metrics
-from ..utils import free_symbol_startswith, sympy_dot, sympy_subs, sympy_symbol, unique
+from ..utils import (
+    DeferredLineBase,
+    free_symbol_startswith,
+    IndentedBuffer,
+    sympy_dot,
+    sympy_subs,
+    sympy_symbol,
+    unique,
+)
 from ..virtualized import ops, V

 log = logging.getLogger(__name__)
@@ -125,102 +130,12 @@ def remainder(a, b):
     return ops.where(f"(({r} != 0) & (({r} < 0) != ({b} < 0)))", ops.add(r, b), r)


-class IndentedBuffer:
-    tabwidth = 4
-
-    def __init__(self, initial_indent=0):
-        self._lines = []
-        self._indent = initial_indent
-
-    def getvalue(
-        self,
-    ):
-        buf = StringIO()
-        for line in self._lines:
-            if isinstance(line, DeferredLine):
-                line = line()
-                if line is None:
-                    continue
-            assert isinstance(line, str)
-            buf.write(line)
-            buf.write("\n")
-        return buf.getvalue()
-
-    def getrawvalue(self):
-        buf = StringIO()
-        for line in self._lines:
-            if isinstance(line, DeferredLine):
-                line = line()
-                if line is None:
-                    continue
-            assert isinstance(line, str)
-            # backslash implies line continuation
-            if line.endswith("\\"):
-                buf.write(line[:-1])
-            else:
-                buf.write(line)
-                buf.write("\n")
-        return buf.getvalue()
-
-    def clear(self):
-        self._lines.clear()
-
-    def __bool__(self):
-        return bool(self._lines)
-
-    def prefix(self):
-        return " " * (self._indent * self.tabwidth)
-
-    def writeline(self, line):
-        if isinstance(line, DeferredLine):
-            self._lines.append(line.with_prefix(self.prefix()))
-        elif line.strip():
-            self._lines.append(f"{self.prefix()}{line}")
-        else:
-            self._lines.append("")
-
-    def writelines(self, lines):
-        for line in lines:
-            self.writeline(line)
-
-    def indent(self, offset=1):
-        @contextlib.contextmanager
-        def ctx():
-            self._indent += offset
-            yield
-            self._indent -= offset
-
-        return ctx()
-
-    def splice(self, other_code, strip=False):
-        if isinstance(other_code, IndentedBuffer):
-            dedent = float("inf")
-            for line in other_code._lines:
-                if line:
-                    dedent = min(dedent, len(line) - len(line.lstrip()))
-            if math.isinf(dedent):
-                dedent = 0
-            for line in other_code._lines:
-                IndentedBuffer.writeline(self, line[dedent:])
-        else:
-            other_code = textwrap.dedent(other_code)
-            if strip:
-                other_code = other_code.lstrip()
-            if not other_code:
-                return
-            other_code = other_code.rstrip()
-            for line in other_code.split("\n"):
-                self.writeline(line)
-
-
-class DeferredLine:
+class DeferredLine(DeferredLineBase):
     """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""

     def __init__(self, name, line):
-        if not line.strip():
-            line = ""
+        super().__init__(line)
         self.name = name
-        self.line = line

     def __call__(self):
         if (
@@ -230,20 +145,8 @@ def __call__(self):
             return self.line
         return None

-    def with_prefix(self, prefix):
-        return DeferredLine(self.name, f"{prefix}{self.line}")
-
-    def lstrip(self):
-        return DeferredLine(self.name, self.line.lstrip())
-
-    def __getitem__(self, index):
-        return DeferredLine(self.name, self.line[index])
-
-    def __bool__(self):
-        return bool(self.line)
-
-    def __len__(self):
-        return len(self.line)
+    def _new_line(self, line):
+        return DeferredLine(self.name, line)


 class DeferredIndentedBuffer(IndentedBuffer):
```
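A side note on the refactor above: `DeferredLine` now keeps only `__call__` and a `_new_line` hook, while the generic line-manipulation helpers move to a shared `DeferredLineBase` in `torch/_inductor/utils.py`, which this diff does not show. Below is a rough sketch of such a base class, reconstructed from the methods deleted here rather than from the actual utils.py source:

```python
# Hedged sketch of a DeferredLineBase-style base class; the real one lives in
# torch/_inductor/utils.py and may differ in details.


class DeferredLineBase:
    """A line of output whose final text is decided later (or dropped)."""

    def __init__(self, line):
        if not line.strip():
            line = ""
        self.line = line

    def __call__(self):
        """Return the final text for this line, or None to omit it."""
        raise NotImplementedError()

    def _new_line(self, line):
        """Return a copy of self carrying `line` in place of the original text."""
        raise NotImplementedError()

    def with_prefix(self, prefix):
        return self._new_line(f"{prefix}{self.line}")

    def lstrip(self):
        return self._new_line(self.line.lstrip())

    def __getitem__(self, index):
        return self._new_line(self.line[index])

    def __bool__(self):
        return bool(self.line)

    def __len__(self):
        return len(self.line)
```

With a base class of this shape, subclasses such as `DeferredLine` only have to say how to rebuild themselves (`_new_line`) and when to suppress themselves (`__call__`).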

torch/_inductor/dependencies.py

Lines changed: 16 additions & 13 deletions
```diff
@@ -7,7 +7,6 @@

 import sympy

-from . import config
 from .codegen.common import index_prevent_reordering
 from .utils import (
     get_dtype_size,
@@ -165,24 +164,15 @@ def merge(self, other):
     )


-class RecordLoadStore(V.MockHandler):  # type: ignore[name-defined]
+class _RecordLoadStoreInner(V.MockHandler):
     def __init__(self, var_ranges: VarRanges, normalize: bool):
-        super(RecordLoadStore, self).__init__()
+        super().__init__()
         self._reads: Set[MemoryDep] = set()
         self._writes: Set[MemoryDep] = set()
         self._index_exprs: Set[IndexExprDep] = set()
         self._var_ranges: VarRanges = var_ranges
         self._normalize: bool = normalize

-    # Truncate the expr str by a threshold to prevent it's too long
-    # and cause process hanging. The result is not used.
-    # https://github.com/pytorch/torchdynamo/issues/1352
-    @staticmethod
-    def truncate_expr(expr):
-        if len(expr) > config.realize_bytes_threshold:
-            expr = f"{expr[:config.realize_bytes_threshold]}..."
-        return expr
-
     def canonicalize(
         self, index: sympy.Expr
     ) -> Tuple[sympy.Expr, Tuple[sympy.Expr, ...]]:
@@ -230,6 +220,14 @@ def index_expr(self, index: sympy.Expr, dtype) -> str:
         return f"index_expr({sympy_str(index)}, {dtype})"


+class RecordLoadStore(V.KernelFormatterHandler):
+    def __init__(self, var_ranges: VarRanges, normalize: bool):
+        parent_handler = _RecordLoadStoreInner(
+            var_ranges=var_ranges, normalize=normalize
+        )
+        super().__init__(parent_handler=parent_handler)
+
+
 def var_builder(prefix: str) -> Tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]:
     cnt = itertools.count()
     var_ranges: VarRanges = collections.OrderedDict()
@@ -279,8 +277,13 @@ def extract_read_writes(
     else:
         range_vars = [*itertools.chain(*args)]

+    inner = rw.parent_handler
     return ReadWrites(
-        set(rw._reads), set(rw._writes), rw._index_exprs, range_vars, var_ranges
+        set(inner._reads),
+        set(inner._writes),
+        inner._index_exprs,
+        range_vars,
+        var_ranges,
     )
```

torch/_inductor/ir.py

Lines changed: 16 additions & 20 deletions
```diff
@@ -357,13 +357,12 @@ def _index(ranges, prefix="i"):

     @cache_on_self
     def inner_fn_str(self):
-        try:
-            with V.set_ops_handler(V.MockHandler()), patch.object(
-                FlexibleLayout, "allow_indexing", True
-            ):
-                return str(self.inner_fn(self._index(self.ranges)))
-        except Exception as e:
-            return f"inner_fn(): {e}"
+        formatter = V.KernelFormatterHandler(V.MockHandler())
+        with V.set_ops_handler(formatter), patch.object(
+            FlexibleLayout, "allow_indexing", True
+        ):
+            result = self.inner_fn(self._index(self.ranges))
+            return formatter.getvalue(result)

     def is_zero_elements(self):
         return any(r == 0 for r in self.ranges)
@@ -479,18 +478,15 @@ def index_length(self):

     @cache_on_self
     def inner_fn_str(self):
-        try:
-            with V.set_ops_handler(V.MockHandler()), patch.object(
-                FlexibleLayout, "allow_indexing", True
-            ):
-                return str(
-                    self.inner_fn(
-                        self._index(self.ranges),
-                        self._index(self.reduction_ranges, "r"),
-                    )
-                )
-        except Exception as e:
-            return f"inner_fn(): {e}"
+        formatter = V.KernelFormatterHandler(V.MockHandler())
+        with V.set_ops_handler(formatter), patch.object(
+            FlexibleLayout, "allow_indexing", True
+        ):
+            result = self.inner_fn(
+                self._index(self.ranges),
+                self._index(self.reduction_ranges, "r"),
+            )
+            return formatter.getvalue(result)

     def constant_to_device(self, device):
         """Move this to a given device. Requires that all reads are to constants."""
@@ -3948,7 +3944,7 @@ def should_realize_on_cpu(loops: Union[Pointwise, Reduction]):
             """
             heavy_ops = ["exp"]  # a list of heavy ops
             fn_str = loops.inner_fn_str()
-            return any([fn_str.startswith(op + "(") for op in heavy_ops])
+            return any([(op + "(") in fn_str for op in heavy_ops])

         if (
             users > 1
```
