From 35a830fc8b9929f81a17c1f613762dad170fcf3f Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Wed, 24 Aug 2022 22:42:28 +0000
Subject: [PATCH 1/6] Add a few lowerings reland

---
 test/test_torchinductor.py      |  9 +++++++++
 torchinductor/codegen/cpp.py    |  4 ++++
 torchinductor/codegen/triton.py |  4 ++++
 torchinductor/decomposition.py  |  6 +-----
 torchinductor/lowering.py       | 23 +++++++++++++++++++++++
 5 files changed, 41 insertions(+), 5 deletions(-)

diff --git a/test/test_torchinductor.py b/test/test_torchinductor.py
index 74271020a0..fea0e0e3aa 100755
--- a/test/test_torchinductor.py
+++ b/test/test_torchinductor.py
@@ -1687,6 +1687,15 @@ def fn(x):
             (torch.randn([64]),),
         )
 
+    def test_flip(self):
+        def fn(x):
+            return torch.flip(x, (-1,)), torch.flip(x, (0, 2)) - 2
+
+        self.common(
+            fn,
+            (torch.randn([1, 2, 6, 6]),),
+        )
+
     def test_log2(self):
         def fn(x):
             return torch.log2(x), torch.log2(x + 1) - 2
diff --git a/torchinductor/codegen/cpp.py b/torchinductor/codegen/cpp.py
index 05bf2f9abb..d1b2ae0925 100644
--- a/torchinductor/codegen/cpp.py
+++ b/torchinductor/codegen/cpp.py
@@ -186,6 +186,10 @@ def exp(x):
     def sqrt(x):
         return f"std::sqrt({x})"
 
+    @staticmethod
+    def rsqrt(x):
+        return f"1 / std::sqrt({x})"
+
     @staticmethod
     def pow(a, b):
         return f"std::pow({a}, {b})"
diff --git a/torchinductor/codegen/triton.py b/torchinductor/codegen/triton.py
index 0499751d75..253199fdce 100644
--- a/torchinductor/codegen/triton.py
+++ b/torchinductor/codegen/triton.py
@@ -159,6 +159,10 @@ def rand(seed, offset, _):  # _ here to keep the contract identical to CPU rand
     def randn(seed, offset, _):  # _ here to keep the contract identical to CPU randn op
         return f"tl.randn({seed}, {offset})"
 
+    @staticmethod
+    def rsqrt(x):
+        return f"tl.libdevice.rsqrt({x})"
+
     @staticmethod
     def pow(a, b):
         return f"tl.libdevice.pow({a}, {b})"
diff --git a/torchinductor/decomposition.py b/torchinductor/decomposition.py
index d9703a1a7f..5c3300134e 100644
--- a/torchinductor/decomposition.py
+++ b/torchinductor/decomposition.py
@@ -83,6 +83,7 @@
     aten.tanh_backward,
     aten.threshold_backward,
     aten.transpose.int,
+    aten.tril.default,
     aten.upsample_nearest2d_backward,
     aten.upsample_bilinear2d.vec,
 ]
@@ -116,11 +117,6 @@ def tanh(x):
     return 2.0 / (1.0 + torch.exp(-2.0 * x)) - 1.0
 
 
-@register_decomposition([aten.rsqrt])
-def rsqrt(x):
-    return torch.reciprocal(torch.sqrt(x))
-
-
 @register_decomposition([aten.log2])
 def log2(x):
     return torch.log(x) * (1.0 / math.log(2.0))
diff --git a/torchinductor/lowering.py b/torchinductor/lowering.py
index 28094d6988..3d5083efc2 100644
--- a/torchinductor/lowering.py
+++ b/torchinductor/lowering.py
@@ -1905,6 +1905,28 @@ def accumulate(out_x, out_y, index_range1, index_range2=None):
     )
 
 
+@register_lowering(prims.rev.default)
+def rev(x, dims):
+    # note - dims are pre-canonicalized
+    x_loader = x.make_loader()
+    sizes = x.get_size()
+
+    def loader(idx):
+        idx = list(idx)
+        assert len(idx) == len(sizes)
+        for dim in dims:
+            idx[dim] = (sizes[dim] - 1) - idx[dim]
+
+        return x_loader(idx)
+
+    return Pointwise.create(
+        device=x.get_device(),
+        dtype=x.get_dtype(),
+        inner_fn=loader,
+        ranges=sizes,
+    )
+
+
 @register_lowering(aten.constant_pad_nd, type_promote=False)
 def constant_pad_nd(x, padding, fill_value=0):
     assert (len(padding) % 2) == 0
@@ -2795,6 +2817,7 @@ def sum_(x, axis=None, keepdims=False, *, dtype=None):
 register_pointwise(aten.reciprocal)
 register_pointwise(aten.remainder)
 register_pointwise(aten.round)
+register_pointwise(aten.rsqrt)
 register_pointwise(aten.sign)
 register_pointwise(aten.silu)
 register_pointwise(aten.ceil)
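PATCH 1 gives the C++ and Triton backends native rsqrt codegen (1 / std::sqrt and tl.libdevice.rsqrt), drops the rsqrt decomposition in favor of register_pointwise(aten.rsqrt), and adds a lowering for prims.rev, which torch.flip reaches via its reference decomposition (hence the new test_flip). The rev lowering never materializes a reversed copy: it is an ordinary Pointwise whose loader rewrites each index as (size[dim] - 1) - idx[dim]. A standalone sketch of that index arithmetic in plain PyTorch (illustration only, not inductor code; flip_via_index_reversal is a name made up for this sketch):

```python
# Illustration only, not inductor code: the per-dim index reversal that the
# rev lowering's loader applies, reproduced with eager PyTorch indexing.
import torch

def flip_via_index_reversal(x: torch.Tensor, dims: tuple) -> torch.Tensor:
    for d in dims:
        # element i along dim d is read from (size[d] - 1) - i, the same
        # remap rev() applies to each index tuple
        reversed_index = torch.arange(x.size(d) - 1, -1, -1, device=x.device)
        x = x.index_select(d, reversed_index)
    return x

x = torch.randn(1, 2, 6, 6)  # same shape test_flip exercises
assert torch.equal(flip_via_index_reversal(x, (0, 2)), torch.flip(x, (0, 2)))
```

Because the result is a plain Pointwise, the reversal can fuse with neighboring ops instead of costing a separate kernel.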
From 508e75ef2cfde4cd69b07c2183f11a584ff09b6d Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Thu, 25 Aug 2022 17:17:21 +0000
Subject: [PATCH 2/6] int promote

---
 test/test_torchinductor.py |  4 ++++
 torchinductor/lowering.py  | 11 ++++++++++-
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/test/test_torchinductor.py b/test/test_torchinductor.py
index fea0e0e3aa..8040a477df 100755
--- a/test/test_torchinductor.py
+++ b/test/test_torchinductor.py
@@ -1686,6 +1686,10 @@ def fn(x):
             fn,
             (torch.randn([64]),),
         )
+        self.common(
+            fn,
+            (torch.ones([64], dtype=torch.int),),
+        )
 
     def test_flip(self):
         def fn(x):
diff --git a/torchinductor/lowering.py b/torchinductor/lowering.py
index 3d5083efc2..36706a9cd4 100644
--- a/torchinductor/lowering.py
+++ b/torchinductor/lowering.py
@@ -2772,6 +2772,16 @@ def fn(*args):
         b if isinstance(b, Number) else to_dtype(b, dtype),
     )
 
+@register_lowering(aten.rsqrt)
+def rsqrt(x):
+    dtype = x.get_dtype()
+    if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
+        x = to_dtype(x, torch.get_default_dtype())
+
+    def _rsqrt(x):
+        return ops.rsqrt(x)
+
+    return make_pointwise(_rsqrt)(x)
 
 @register_lowering([aten.sum, prims.sum])
 def sum_(x, axis=None, keepdims=False, *, dtype=None):
@@ -2817,7 +2827,6 @@ def sum_(x, axis=None, keepdims=False, *, dtype=None):
 register_pointwise(aten.reciprocal)
 register_pointwise(aten.remainder)
 register_pointwise(aten.round)
-register_pointwise(aten.rsqrt)
 register_pointwise(aten.sign)
 register_pointwise(aten.silu)
 register_pointwise(aten.ceil)

From 2d332e7b935b8aab94a022b59a74cd0360a94bed Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Thu, 25 Aug 2022 17:25:14 +0000
Subject: [PATCH 3/6] lint

---
 torchinductor/lowering.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torchinductor/lowering.py b/torchinductor/lowering.py
index 36706a9cd4..92a5c15876 100644
--- a/torchinductor/lowering.py
+++ b/torchinductor/lowering.py
@@ -2772,6 +2772,7 @@ def fn(*args):
         b if isinstance(b, Number) else to_dtype(b, dtype),
     )
 
+
 @register_lowering(aten.rsqrt)
 def rsqrt(x):
     dtype = x.get_dtype()
@@ -2783,6 +2784,7 @@ def _rsqrt(x):
 
     return make_pointwise(_rsqrt)(x)
 
+
 @register_lowering([aten.sum, prims.sum])
 def sum_(x, axis=None, keepdims=False, *, dtype=None):
     if (
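PATCH 2 swaps the plain register_pointwise(aten.rsqrt) for a dedicated lowering so integer and bool inputs are promoted to the default float dtype before the pointwise op runs, matching eager type-promotion semantics (the added test case feeds torch.int input); PATCH 3 only adds the blank lines the linter wants. A quick eager-mode check of the promotion rule the lowering mirrors (a sketch, not inductor code):

```python
# Eager-mode check, not inductor code: aten.rsqrt is a float-promoting unary
# op, so an integer tensor comes back in the default float dtype.
import torch

x_int = torch.ones([64], dtype=torch.int)
y = torch.rsqrt(x_int)
assert y.dtype == torch.get_default_dtype()  # float32 unless reconfigured
assert torch.all(y == 1.0)  # rsqrt(1) == 1
```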
From 6e7b69f58608d01748640dce3a7bc460b99afede Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Fri, 26 Aug 2022 16:07:33 +0000
Subject: [PATCH 4/6] disable rsqrt

---
 torchinductor/decomposition.py |  5 +++++
 torchinductor/lowering.py      | 22 ++++++++++++----------
 2 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/torchinductor/decomposition.py b/torchinductor/decomposition.py
index 5c3300134e..1c130ee07b 100644
--- a/torchinductor/decomposition.py
+++ b/torchinductor/decomposition.py
@@ -152,6 +152,11 @@ def special_erf(x):
     return sign * y
 
 
+@register_decomposition([aten.rsqrt])
+def rsqrt(x):
+    return torch.reciprocal(torch.sqrt(x))
+
+
 @register_decomposition([aten.rsub.Tensor, aten.rsub.Scalar])
 def rsub(a, b):
     if isinstance(b, numbers.Number):
diff --git a/torchinductor/lowering.py b/torchinductor/lowering.py
index 92a5c15876..cf8d6cab40 100644
--- a/torchinductor/lowering.py
+++ b/torchinductor/lowering.py
@@ -2773,16 +2773,18 @@ def fn(*args):
     )
 
 
-@register_lowering(aten.rsqrt)
-def rsqrt(x):
-    dtype = x.get_dtype()
-    if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
-        x = to_dtype(x, torch.get_default_dtype())
-
-    def _rsqrt(x):
-        return ops.rsqrt(x)
-
-    return make_pointwise(_rsqrt)(x)
+# TODO - enable builtin and disable decomp to lower to ptx instruction
+# Causes compilation to not complete on timm_vision_transformers inference
+# @register_lowering(aten.rsqrt)
+# def rsqrt(x):
+#     dtype = x.get_dtype()
+#     if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
+#         x = to_dtype(x, torch.get_default_dtype())
+#
+#     def _rsqrt(x):
+#         return ops.rsqrt(x)
+#
+#     return make_pointwise(_rsqrt)(x)
 
 
 @register_lowering([aten.sum, prims.sum])

From 73e9b8f91d36626ee91bdd56e26427513afa5ab2 Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Fri, 26 Aug 2022 18:02:18 +0000
Subject: [PATCH 5/6] not changing rsqrt anymore - fix int in follow up

---
 test/test_torchinductor.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/test/test_torchinductor.py b/test/test_torchinductor.py
index 8040a477df..fea0e0e3aa 100755
--- a/test/test_torchinductor.py
+++ b/test/test_torchinductor.py
@@ -1686,10 +1686,6 @@ def fn(x):
             fn,
             (torch.randn([64]),),
         )
-        self.common(
-            fn,
-            (torch.ones([64], dtype=torch.int),),
-        )
 
     def test_flip(self):
         def fn(x):

From b35cdc59a9dde6d2b2aafce769dbf6d2e7de446b Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Mon, 29 Aug 2022 15:01:57 +0000
Subject: [PATCH 6/6] lint

---
 torchinductor/decomposition.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/torchinductor/decomposition.py b/torchinductor/decomposition.py
index 470a380cf7..eb61c5ce9b 100644
--- a/torchinductor/decomposition.py
+++ b/torchinductor/decomposition.py
@@ -160,11 +160,6 @@ def special_erf(x):
     return sign * y
 
 
-@register_decomposition([aten.rsqrt])
-def rsqrt(x):
-    return torch.reciprocal(torch.sqrt(x))
-
-
 @register_decomposition([aten.rsub.Tensor, aten.rsub.Scalar])
 def rsub(a, b):
     if isinstance(b, numbers.Number):
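PATCHes 4 through 6 walk the rsqrt change back: the dedicated lowering is commented out with a TODO because routing rsqrt to the libdevice/ptx instruction made compilation hang on timm_vision_transformers inference, the reciprocal-of-sqrt decomposition is reinstated (PATCH 4; PATCH 6 then drops what is apparently a duplicate registration the linter flagged after rebase), and the int-input test is removed again with the integer case deferred to a follow-up (PATCH 5). A plain-PyTorch equivalence check of the reinstated decomposition (illustration only, not inductor code):

```python
# Illustration only: the reinstated decomposition computes rsqrt as
# reciprocal(sqrt(x)); a quick eager-mode equivalence check.
import torch

x = torch.rand(1024) + 0.1  # keep inputs positive so sqrt stays finite
assert torch.allclose(torch.rsqrt(x), torch.reciprocal(torch.sqrt(x)))
```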