
Commit e204dd5

jansel authored and facebook-github-bot committed
[inductor] Rewrite Triton templates + epilogue fusion (retry) (pytorch#91575)
Summary: This reverts commit 94262ef to reland pytorch#91105 / pytorch#90738.

Fixes pytorch/torchdynamo#2015

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

Pull Request resolved: pytorch#91575

Reviewed By: ngimel

Differential Revision: D42304332

Pulled By: jansel

fbshipit-source-id: 1eefc7320da5de7544d048c5b7ea8716930f31cf
1 parent 7f2b5ea commit e204dd5

34 files changed: +1584 -1956 lines changed

benchmarks/dynamo/microbenchmarks/microbench.py

+1 -1

@@ -139,7 +139,7 @@ def main():
     if args.verbose:
         torch._inductor.config.debug = True

-    torch._inductor.config.triton.autotune = True
+    torch._inductor.config.triton.autotune_pointwise = True

     rows = []
     for model in (MicroBenchmarks.sum,):
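
The renamed flag is used the same way as the old one. Below is a minimal, illustrative sketch (not taken from the benchmark itself) of enabling pointwise autotuning before compiling a simple elementwise function, assuming a CUDA build of PyTorch with Triton available:

import torch
import torch._inductor.config as inductor_config

# Renamed knob: governs autotuning of generated pointwise Triton kernels only;
# matmul template selection is driven by the separate max_autotune option
# exercised by the new test file further down.
inductor_config.triton.autotune_pointwise = True

@torch.compile
def scale_and_add(x, y):
    # Simple pointwise op whose kernel launch configuration can be autotuned.
    return x * 2.0 + y

if torch.cuda.is_available():
    out = scale_and_add(torch.randn(1024, device="cuda"), torch.randn(1024, device="cuda"))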

benchmarks/dynamo/torchbench.py

+6

@@ -118,6 +118,10 @@ def setup_torchbench_cwd():
     "tacotron2",
 }

+REQUIRE_HIGHER_FP16_TOLERANCE = {
+    "drq",
+}
+
 REQUIRE_COSINE_TOLERACE = {
     # Just keeping it here even though its empty, if we need this in future.
 }
@@ -335,6 +339,8 @@ def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
         cosine = self.args.cosine
         # Increase the tolerance for torch allclose
         if self.args.float16 or self.args.amp:
+            if name in REQUIRE_HIGHER_FP16_TOLERANCE:
+                return 1e-2, cosine
             return 1e-3, cosine
         if is_training and current_device == "cuda":
             tolerance = 1e-3
setup.py

-1

@@ -1156,7 +1156,6 @@ def main():
                 'include/THH/generic/*.h',
                 'include/sleef.h',
                 "_inductor/codegen/*.h",
-                "_inductor/codegen/*.j2",
                 'share/cmake/ATen/*.cmake',
                 'share/cmake/Caffe2/*.cmake',
                 'share/cmake/Caffe2/public/*.cmake',
test/inductor/test_select_algorithm.py

+146
@@ -0,0 +1,146 @@
+# Owner(s): ["module: inductor"]
+import functools
+import logging
+from unittest.mock import patch
+
+import torch
+import torch._dynamo.config as dynamo_config
+import torch._inductor.config as inductor_config
+import torch._inductor.select_algorithm as select_algorithm
+import torch.nn.functional as F
+from torch._dynamo.test_case import run_tests, TestCase
+from torch._dynamo.utils import counters
+from torch.testing._internal.common_utils import IS_LINUX
+from torch.testing._internal.inductor_utils import HAS_CUDA
+
+torch.backends.cuda.matmul.allow_tf32 = False
+
+
+def patches(fn):
+    def skip_cache(self, key, generate):
+        return generate()
+
+    for patcher in [
+        patch.object(dynamo_config, "log_level", logging.INFO),
+        patch.object(dynamo_config, "verbose", True),
+        patch.object(inductor_config, "debug", True),
+        patch.object(inductor_config, "max_autotune", True),
+        patch.object(inductor_config, "epilogue_fusion", True),
+        patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)),
+        patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache),
+    ]:
+        fn = patcher(fn)
+
+    @functools.wraps(fn)
+    def wrapped(*args, **kwargs):
+        counters.clear()
+        torch.manual_seed(12345)
+        assert (
+            not torch.backends.cuda.matmul.allow_tf32
+        ), "correctness testing is allergic to tf32"
+        return fn(*args, **kwargs)
+
+    return wrapped
+
+
+class TestSelectAlgorithm(TestCase):
+    @patches
+    def test_linear_relu(self):
+        @torch.compile
+        def foo(input, weight, bias):
+            return F.relu(F.linear(input, weight, bias))
+
+        foo(
+            torch.randn(64, 32, device="cuda"),
+            torch.randn(16, 32, device="cuda"),
+            torch.randn(16, device="cuda"),
+        )
+        # Autotuning checks correctness of each version
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
+        # It would be nice to assert this got fused into a single kernel, but that
+        # only happens if we select a triton template (and not aten).
+
+    @patches
+    def test_addmm(self):
+        @torch.compile
+        def foo(input, weight, bias):
+            return torch.addmm(bias, input, weight)
+
+        foo(
+            torch.randn(20, 33, device="cuda"),
+            torch.randn(33, 16, device="cuda"),
+            torch.randn(20, 16, device="cuda"),
+        )
+        # Autotuning checks correctness of each version
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
+
+    @patches
+    def test_mm(self):
+        @torch.compile
+        def foo(a, b):
+            return torch.mm(a, b)
+
+        foo(
+            torch.randn(8, 32, device="cuda"),
+            torch.randn(32, 8, device="cuda"),
+        )
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
+
+    @patches
+    def test_mm_skip(self):
+        @torch.compile
+        def foo(a, b):
+            return torch.mm(a, b)
+
+        foo(
+            torch.randn(8, 32, device="cuda", dtype=torch.float64),
+            torch.randn(32, 8, device="cuda", dtype=torch.float64),
+        )
+        # float64 not supported by tl.dot()
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)
+
+    @patches
+    def test_bmm(self):
+        @torch.compile
+        def foo(a, b):
+            return torch.bmm(a, b)
+
+        foo(
+            torch.randn(2, 8, 32, device="cuda"),
+            torch.randn(2, 32, 8, device="cuda"),
+        )
+        # Autotuning checks correctness of each version
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
+
+    @patches
+    def test_mm_not_even_k(self):
+        @torch.compile
+        def foo(a, b):
+            return torch.mm(a, b)
+
+        foo(
+            torch.randn(11, 22, device="cuda"),
+            torch.randn(22, 33, device="cuda"),
+        )
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
+
+    @patches
+    def test_baddbmm(self):
+        @torch.compile
+        def foo(a, b, c):
+            return torch.baddbmm(c, a, b)
+
+        foo(
+            torch.randn(2, 8, 32, device="cuda"),
+            torch.randn(2, 32, 8, device="cuda"),
+            torch.randn(2, 1, 8, device="cuda"),
+        )
+        # Autotuning checks correctness of each version
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
+
+
+if __name__ == "__main__":
+    from torch._inductor.utils import is_big_gpu
+
+    if IS_LINUX and HAS_CUDA and is_big_gpu(0):
+        run_tests()
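
Outside of this test file, the same code paths can be exercised by flipping the two configs the test patches. A minimal sketch, assuming a CUDA build of PyTorch with Triton available (the function and shapes are illustrative, not part of the commit):

import torch
import torch._inductor.config as inductor_config
import torch.nn.functional as F

# The two knobs the test patches: consider Triton matmul templates during
# autotuning, and allow the pointwise epilogue (relu here) to be fused into
# the winning template when one is selected.
inductor_config.max_autotune = True
inductor_config.epilogue_fusion = True

@torch.compile
def linear_relu(x, w, b):
    return F.relu(F.linear(x, w, b))

if torch.cuda.is_available():
    y = linear_relu(
        torch.randn(64, 32, device="cuda"),
        torch.randn(16, 32, device="cuda"),
        torch.randn(16, device="cuda"),
    )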

test/inductor/test_torchinductor.py

+3 -89

@@ -79,7 +79,7 @@
     unittest.skipIf, not HAS_MULTIGPU, "requires multiple cuda devices"
 )

-torch._inductor.config.triton.autotune = False  # too slow
+torch._inductor.config.triton.autotune_pointwise = False  # too slow


 # For OneDNN bf16 path, OneDNN requires the cpu has intel avx512 with avx512bw,
@@ -2505,76 +2505,6 @@ def fn(x, y):
         self.assertEqual(a.stride(), c.stride())
         self.assertEqual(c.stride()[2], 1)

-    @requires_cuda()
-    @patch.object(config.triton, "convolution", "triton")
-    @patch.object(config.triton, "dense_indexing", "True")
-    def test_triton_conv(self):
-        @torch._dynamo.optimize("inductor", nopython=True)
-        def triton_conv(
-            x,
-            w,
-            bias,
-            stride,
-            padding,
-            dilation,
-            groups,
-        ):
-            y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
-            return y
-
-        stride, padding, dilation, groups = (1, 1), (0, 0), (1, 1), 1
-        dtype = torch.float32
-        x = torch.randn((32, 128, 32, 32), dtype=dtype, device=self.device)
-        w = torch.randn((32, 128, 1, 1), dtype=dtype, device=self.device)
-        bias = torch.randn((32), dtype=dtype, device=self.device)
-
-        y = triton_conv(x, w, bias, stride, padding, dilation, groups)
-        y_correct = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
-        self.assertTrue(same(y, y_correct, cos_similarity=True, tol=0.1))
-
-    @requires_cuda()
-    @patch.object(config.triton, "convolution", "autotune")
-    @patch.object(config.triton, "dense_indexing", "True")
-    def test_conv_autotune(self):
-        @torch._dynamo.optimize("inductor", nopython=True)
-        def triton_conv(
-            x,
-            w,
-            bias,
-            stride,
-            padding,
-            dilation,
-            groups,
-        ):
-            y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
-            return y
-
-        stride, padding, dilation, groups = (1, 1), (0, 0), (1, 1), 1
-        dtype = torch.float32
-        x = torch.randn((32, 128, 32, 32), dtype=dtype, device=self.device)
-        w = torch.randn((32, 128, 1, 1), dtype=dtype, device=self.device)
-        bias = torch.randn((32), dtype=dtype, device=self.device)
-
-        y = triton_conv(x, w, bias, stride, padding, dilation, groups)
-        y_correct = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
-        self.assertTrue(same(y, y_correct, cos_similarity=True, tol=0.1))
-
-    @patch.object(config.triton, "mm", "triton")
-    def test_triton_mm2(self):
-        @torch._dynamo.optimize("inductor", nopython=True)
-        def fn(x, y):
-            return torch.relu(torch.mm(x, y))
-
-        N = 1024
-        a = torch.randn([N, N], device=self.device, dtype=torch.float32)
-        b = torch.randn([N, N], device=self.device, dtype=torch.float32)
-        c1 = torch.relu(torch.mm(a, b))
-        torch._inductor.metrics.reset()
-        c = fn(a, b)
-        assert torch.allclose(c1, c, atol=1e-3, rtol=1e-3)
-        if self.device == "cuda":
-            assert torch._inductor.metrics.generated_kernel_count == 1
-
     def test_std(self):
         def fn(x):
             return (
@@ -4560,12 +4490,6 @@ def fn(a, b):
         )
         expected_kernel = 0
         # codegen mm kernel from template
-        if config.triton.mm != "aten" and self.device == "cuda":
-            expected_kernel = 1
-        if config.triton.mm == "autotune":
-            self.assertLessEqual(
-                torch._inductor.metrics.generated_kernel_count, expected_kernel
-            )
         self.assertEqual(
             torch._inductor.metrics.generated_kernel_count, expected_kernel
         )
@@ -4641,15 +4565,6 @@ def run(x):
         result.sum().backward()

         expected_kernel = 4
-        if config.triton.mm != "aten" and self.device == "cuda":
-            # fwd: 2 * (mm+dropout) kernels = 2 kernels
-            # bwd: dropout + (mm) + 2 * (mm+dropout) kernels = 4 kernels
-            # expect 2 + 4 = 6 kernels
-            expected_kernel = 6
-        if config.triton.mm == "autotune":
-            self.assertLessEqual(
-                torch._inductor.metrics.generated_kernel_count, expected_kernel
-            )
         self.assertEqual(
             torch._inductor.metrics.generated_kernel_count, expected_kernel
         )
@@ -4979,7 +4894,6 @@ def fn(x, y):
         inputs = (inputs[1], inputs[0])
         self.assertTrue(same(opt(*inputs), fn(*inputs)))

-    @patch.object(config.triton, "mm", "aten")
     def test_list_clearing(self):

         if self.device == "cpu":
@@ -5685,7 +5599,7 @@ def forward(self, view, reshape_2):
         res = opt_mod(*args)
         self.assertTrue(same(ref, res))

-    @patch.object(config.triton, "autotune", True)
+    @patch.object(config.triton, "autotune_pointwise", True)
     def test_inplace_add_alpha_autotune(self):
         def fn(x, y):
             aten.add_.Tensor(x, y, alpha=0.55)
@@ -5703,7 +5617,7 @@ def fn(x, y):
         fn_compiled([x3, y])
         assert same(x2, x3)

-    @patch.object(config.triton, "autotune", True)
+    @patch.object(config.triton, "autotune_pointwise", True)
     def test_inplace_buffer_autotune(self):
         def foo(x, y, z):
             a = x @ y
torch/_dynamo/testing.py

+6 -2

@@ -242,8 +242,12 @@ def _fn(*args, **kwargs):
     return _fn


-def rand_strided(size, stride, dtype=torch.float32, device="cpu"):
-    needed_size = sum((shape - 1) * stride for shape, stride in zip(size, stride)) + 1
+def rand_strided(size, stride, dtype=torch.float32, device="cpu", extra_size=0):
+    needed_size = (
+        sum((shape - 1) * stride for shape, stride in zip(size, stride))
+        + 1
+        + extra_size
+    )
     if dtype.is_floating_point:
         buffer = torch.randn(needed_size, dtype=dtype, device=device)
     else:
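
For reference, the new extra_size argument simply pads the backing allocation beyond what the size/stride combination strictly requires. A short usage sketch of the updated helper, with values chosen only for illustration:

import torch
from torch._dynamo.testing import rand_strided

# A (4, 4) view with strides (8, 1) needs (4 - 1) * 8 + (4 - 1) * 1 + 1 = 28 elements;
# extra_size=16 allocates 16 more, so callers can request slack in the underlying buffer.
t = rand_strided((4, 4), (8, 1), dtype=torch.float32, device="cpu", extra_size=16)
print(t.shape, t.stride())  # torch.Size([4, 4]) (8, 1)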
