
Commit f686d82

[NVFuser] Upstream push 0714
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:

- codegen improvements:
  1. Indexing refactor -> remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor:
  1. New compute-at interface
- transformation propagation refactor on MaxInfoSpanningTree:
  1. Added a sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip the transform propagator when possible
  3. SpanningTreePrinter for debugging
- parser updates:
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand, to support expanding to a concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR the GitHub API. Commits actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (csarofeen#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (csarofeen#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (csarofeen#1811)
03180aa improve broadcast resolution (csarofeen#1792)
bee6c69 bug fix (csarofeen#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (csarofeen#1812)
de6b7ca Fix negative position in InlinePropagator (csarofeen#1813)
10a996c Remove redundant check in schedulePointwise (csarofeen#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (csarofeen#1441)
3ed8330 Kernel args patch to show zero_init buffer (csarofeen#1809)
037a75a Dropout prob extremal patch (csarofeen#1804)
282c429 spam nvrtc options (csarofeen#1783)
3ba6a5f Broadcast in dim with expand (csarofeen#1794)
fd4be12 remove dead indexing code (csarofeen#1806)
fa4e6a4 Check siblings in getMaxPosAll (csarofeen#1805)
025c840 Grouping grid allreduces across iterations (csarofeen#1755)
37c579e Temporarily disable test requring large shared memory. (csarofeen#1802)
5f375d0 More cleanup on InlinePropagator (csarofeen#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (csarofeen#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (csarofeen#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (csarofeen#1756)
ef04f6c Coding style cleanups (csarofeen#1798)
38c7f3c InlinePropagator please don't replay (csarofeen#1797)
3f2c263 validateDomain in TransformPropagator (csarofeen#1796)
c077085 Use TransformPropagatorWithCheck in many tests (csarofeen#1795)
d0d0908 Some further cleanup for the new computeAt interface (csarofeen#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (csarofeen#1791)
28cbaf9 New compute at interface (csarofeen#1743)
635ebfc Add SpanningTreePrinter (csarofeen#1786)
59f3c32 Output allocate patch (csarofeen#1790)
fe93bf5 Transform propagator skip replay when possible (csarofeen#1782)
ebf23a5 Fix isIntegralType error msg (csarofeen#1789)
0c82ecf Disable register reuse across serial broadcast ops (csarofeen#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (csarofeen#1776)
86f46aa Fix div(Val, TensorView) (csarofeen#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (csarofeen#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (csarofeen#1761)
```

[ghstack-poisoned]
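As a quick illustration of how the parser updates above surface to users, the sketch below scripts a function that calls `_to_copy` and routes it through the nvfuser backend, mirroring the `test_to_copy` case added in this commit. This is a minimal sketch assuming a CUDA build with nvfuser available; the shapes, dtype, and warm-up count are illustrative only.

```python
import torch

# Minimal sketch (not part of this commit's diff): exercise the new
# `_to_copy` parsing by scripting a function and selecting the nvfuser
# backend ("fuser2"). Assumes a CUDA build with nvfuser enabled.
def t(x, dtype: torch.dtype):
    return torch.ops.aten._to_copy(x, dtype=dtype)

t_jit = torch.jit.script(t)
x = torch.randn(4, 2, device="cuda")

with torch.jit.fuser("fuser2"):
    for _ in range(3):  # a few runs so the profiling executor can fuse
        out = t_jit(x, torch.float16)

print(out.dtype)  # torch.float16
```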
1 parent 2fb2740 commit f686d82


76 files changed: +6375 −2049 lines

build_variables.bzl

Lines changed: 2 additions & 1 deletion
@@ -32,6 +32,7 @@ libtorch_nvfuser_runtime_sources = [
     "torch/csrc/jit/codegen/cuda/runtime/helpers.cu",
     "torch/csrc/jit/codegen/cuda/runtime/index_utils.cu",
     "torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu",
+    "torch/csrc/jit/codegen/cuda/runtime/swizzle.cu",
     "torch/csrc/jit/codegen/cuda/runtime/memory.cu",
     "torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu",
     "torch/csrc/jit/codegen/cuda/runtime/tensor.cu",
@@ -643,6 +644,7 @@ libtorch_cuda_core_sources = [
     "torch/csrc/autograd/functions/comm.cpp",
     "torch/csrc/jit/codegen/cuda/arith.cpp",
     "torch/csrc/jit/codegen/cuda/compute_at.cpp",
+    "torch/csrc/jit/codegen/cuda/inline_propagator.cpp",
     "torch/csrc/jit/codegen/cuda/compute_at_map.cpp",
     "torch/csrc/jit/codegen/cuda/codegen.cpp",
     "torch/csrc/jit/codegen/cuda/contiguity.cpp",
@@ -658,7 +660,6 @@ libtorch_cuda_core_sources = [
     "torch/csrc/jit/codegen/cuda/grouped_reduction.cpp",
     "torch/csrc/jit/codegen/cuda/index_compute.cpp",
     "torch/csrc/jit/codegen/cuda/lower_index_compute.cpp",
-    "torch/csrc/jit/codegen/cuda/index_reference_replay.cpp",
     "torch/csrc/jit/codegen/cuda/instrumentation.cpp",
     "torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp",
     "torch/csrc/jit/codegen/cuda/ir_builder.cpp",

test/test_jit_cuda_fuser.py

Lines changed: 83 additions & 16 deletions
@@ -83,6 +83,11 @@ def is_pre_volta():
 
 TEST_BF16 = RUN_NVFUSER and torch.cuda.is_bf16_supported()
 
+TEST_LARGE_TENSOR = RUN_NVFUSER
+if RUN_NVFUSER:
+    torch.ones(1).cuda() # initialize cuda context
+    TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 12e9
+
 class CudaFuserTestOptions():
     def __init__(self):
         self.old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
@@ -184,23 +189,27 @@ def tearDown(self):
         self.cuda_fuser_options.restore()
         super(TestCudaFuser, self).tearDown()
 
-    def _run_helper(self, jit_op, op, *args, check_stride=False, num_fusion=1):
-        torch.cuda.manual_seed_all(123)
-        jit_o = jit_op(*args)
-        torch.cuda.manual_seed_all(123)
+    def _run_helper(self, jit_op, op, *args, check_stride=False, num_fusion=1, check_runs=1):
+        seed = 123
+        torch.cuda.manual_seed_all(seed)
         jit_o = jit_op(*args)
-        torch.cuda.manual_seed_all(123)
-        o = op(*args)
 
-        if type(jit_o) is torch.Tensor:
-            jit_o = [jit_o, ]
-            o = [o, ]
+        for i in range(check_runs):
+            torch.cuda.manual_seed_all(seed + i)
+            jit_o = jit_op(*args)
+            torch.cuda.manual_seed_all(seed + i)
+            o = op(*args)
+
+            if type(jit_o) is torch.Tensor:
+                jit_o = [jit_o, ]
+                o = [o, ]
+
+            for oo, jit_oo in zip(o, jit_o):
+                self.assertEqual(oo.dtype, jit_oo.dtype)
+                self.assertEqual(oo, jit_oo)
+                if check_stride:
+                    self.assertEqual(oo.stride(), jit_oo.stride())
 
-        for oo, jit_oo in zip(o, jit_o):
-            self.assertEqual(oo.dtype, jit_oo.dtype)
-            self.assertEqual(oo, jit_oo)
-            if check_stride:
-                self.assertEqual(oo.stride(), jit_oo.stride())
         self.assertGraphContainsExactly(jit_op.graph_for(*args), FUSION_GUARD, num_fusion, consider_subgraphs=True)
 
     def _run_training_helper(self, jit_op, op, grads, *args):
@@ -2563,13 +2572,14 @@ def t(x: torch.Tensor, p: float, train: bool):
 
         self._run_helper(t_jit, t, x, 0.15, False)
 
+    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
     @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                      "Requires fusion optimization pass to be effective")
     def test_dropout_train_nograd_fusion(self):
         dtype = torch.float
         device = "cuda"
-        x = torch.randn([10, 4, 8], dtype=dtype, device=device)
+        x = torch.randn([64, 128, 1024], dtype=dtype, device=device)
 
         def t(x: torch.Tensor, p: float, train: bool):
             o = torch.nn.functional.dropout(x, p, training=train)
@@ -2578,7 +2588,8 @@ def t(x: torch.Tensor, p: float, train: bool):
 
         t_jit = torch.jit.script(t)
 
-        self._run_helper(t_jit, t, x, 0.0, True)
+        self._run_helper(t_jit, t, x, 0.0, True, check_runs=20)
+        self._run_helper(t_jit, t, x, 1.0, True, check_runs=20)
 
     @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
@@ -4391,6 +4402,33 @@ def t(x):
         t_jit = torch.jit.script(t)
         self._run_helper(t_jit, t, x)
 
+
+    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
+                     "Requires fusion optimization pass to be effective")
+    def test_to_copy(self):
+        x = torch.randn(4, 2, device="cuda")
+
+        with nvfuser_singleton_fusion(True):
+            def t(x, dtype : torch.dtype):
+                o = torch.ops.aten._to_copy(x, dtype=dtype)
+                return o
+
+            t.__disable_jit_function_caching__ = True
+
+            t_jit = torch.jit.script(t)
+            for dtype in [torch.float16, torch.bool, torch.float64]:
+                self._run_helper(t_jit, t, x, dtype)
+
+            def t_none(x):
+                with torch.jit.strict_fusion():
+                    o = torch.ops.aten._to_copy(x, dtype=None)
+                return o
+
+            t_jit_none = torch.jit.script(t_none)
+            self._run_helper(t_jit_none, t_none, x)
+
+
     @unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since reshape is disabled now")
     @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
@@ -4752,6 +4790,35 @@ def t(x):
         jit_t = torch.jit.script(t)
         self._run_helper(jit_t, t, x)
 
+    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
+                     "Requires fusion optimization pass to be effective")
+    def test_issue_1785(self):
+        class Fusion(torch.nn.Module):
+            def __init__(self):
+                super(Fusion, self).__init__()
+
+            def forward(self, x, a, b):
+                out = torch.mul(x.unsqueeze(-1), a)
+                out = out + b
+                return out
+
+        x = torch.randn(1024, 192, 3, device='cuda')
+        a = torch.randn(3, 128, device='cuda')
+        b = torch.randn(3, 128, device='cuda')
+
+        model = Fusion()
+        jit_model = torch.jit.script(model)
+
+        with torch.jit.fuser('fuser2'):
+            for _ in range(4):
+                out_ref = model(x, a, b)
+                out_jit = jit_model(x, a, b)
+
+            out_ref = model(x, a, b)
+            out_jit = jit_model(x, a, b)
+            self.assertTrue(self._compare("comparing output failed", out_ref, out_jit, 1e-5))
+
     @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                      "Requires fusion optimization pass to be effective")

torch/csrc/jit/codegen/cuda/README.md

Lines changed: 1 addition & 1 deletion
@@ -187,7 +187,7 @@ There're a few debug dump that could be turned on via environment variables. Loo
 1. `dump_eff_bandwidth`: print out effective bandwidth of each generated kernel. This naively measure the kernel time divided by I/O buffer size and is a good/simple metric of performance for bandwidth bound kernels
 2. `cuda_kernel`: print out generated cuda kernels
 3. `launch_param`: print out launch config of generated kernels
-4. `print_args`: print out input output tensors of executed codegen kernels
+4. `kernel_args`: print out input/output/buffer tensors of all executed codegen kernels, note that for buffers, we indicate whether they are zero-initialized, which hints on an extra kernel to fill the tensor before codegen kernels.
 
 ### FAQs
 
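For reference, the renamed dump option is enabled through the `PYTORCH_NVFUSER_DUMP` environment variable, typically from the shell (`PYTORCH_NVFUSER_DUMP=kernel_args python script.py`). The Python sketch below sets it before importing torch so the option is visible when kernels execute; the scripted function is illustrative only and assumes a CUDA build with nvfuser.

```python
# Sketch: enable the kernel_args dump before importing torch, then run a
# scripted function under the nvfuser backend so some codegen kernels execute.
import os
os.environ["PYTORCH_NVFUSER_DUMP"] = "kernel_args"

import torch

def f(x):
    return torch.relu(x) * 2.0

f_jit = torch.jit.script(f)
x = torch.randn(1024, device="cuda")

with torch.jit.fuser("fuser2"):
    for _ in range(3):
        f_jit(x)  # input/output/buffer tensors are printed for executed kernels
```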

torch/csrc/jit/codegen/cuda/arith.cpp

Lines changed: 24 additions & 1 deletion
@@ -458,7 +458,6 @@ TensorView* unaryOp(
 }
 
 NVFUSER_DEFINE_UNARY_OP(set, Set)
-NVFUSER_DEFINE_UNARY_OP(randlike, RandLike)
 NVFUSER_DEFINE_UNARY_OP(ceil, Ceil)
 NVFUSER_DEFINE_UNARY_OP(floor, Floor)
 NVFUSER_DEFINE_UNARY_OP(frac, Frac)
@@ -469,6 +468,30 @@ NVFUSER_DEFINE_UNARY_OP(silu, Silu)
 NVFUSER_DEFINE_UNARY_OP(trunc, Trunc)
 #undef NVFUSER_DEFINE_UNARY_OP
 
+Val* randlike(Val* v) {
+  TORCH_CHECK(
+      isFloatingPointType(v->dtype()),
+      "input must have floating point type, but got ",
+      v->dtype());
+  auto rand_vals = unaryOp(UnaryOpType::RandLike, v);
+  return where(
+      eq(rand_vals, IrBuilder::create<Double>(1.0)),
+      IrBuilder::create<Double>(0.0),
+      rand_vals);
+}
+
+TensorView* randlike(TensorView* v) {
+  TORCH_CHECK(
+      isFloatingPointType(v->dtype()),
+      "input must have floating point type, but got ",
+      v->dtype());
+  auto rand_vals = unaryOp(UnaryOpType::RandLike, v);
+  return where(
+      eq(rand_vals, IrBuilder::create<Double>(1.0)),
+      IrBuilder::create<Double>(0.0),
+      rand_vals);
+}
+
 Val* bitwise_not(Val* v) {
   TORCH_CHECK(
       isIntegralType(v->dtype()) || v->dtype() == DataType::Bool,
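The new `randlike` overloads wrap the raw `RandLike` op and remap a value of exactly 1.0 to 0.0, keeping results in the half-open interval [0, 1) (the "Dropout prob extremal patch" from the commit message). Below is a Python sketch of the same remapping using plain torch ops, purely for illustration; `torch.rand_like` stands in for `UnaryOpType::RandLike`.

```python
import torch

def randlike_half_open(v: torch.Tensor) -> torch.Tensor:
    """Illustrative equivalent of the new randlike: draw uniform values and
    remap exact 1.0 to 0.0 so every result lies in [0, 1)."""
    assert v.is_floating_point(), "input must have floating point type"
    rand_vals = torch.rand_like(v)  # stand-in for UnaryOpType::RandLike
    return torch.where(rand_vals == 1.0, torch.zeros_like(rand_vals), rand_vals)

x = torch.empty(4, 4, dtype=torch.float32)
r = randlike_half_open(x)
assert ((r >= 0) & (r < 1)).all()
```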
