diff --git a/test/test_torchinductor.py b/test/test_torchinductor.py
index 2b05b8546b..a5bfa3d652 100755
--- a/test/test_torchinductor.py
+++ b/test/test_torchinductor.py
@@ -1655,6 +1655,15 @@ def fn(x):
             (torch.randn([64]),),
         )
 
+    def test_flip(self):
+        def fn(x):
+            return torch.flip(x, (-1,)), torch.flip(x, (0, 2)) - 2
+
+        self.common(
+            fn,
+            (torch.randn([1, 2, 6, 6]),),
+        )
+
     def test_log2(self):
         def fn(x):
             return torch.log2(x), torch.log2(x + 1) - 2
diff --git a/torchinductor/codegen/cpp.py b/torchinductor/codegen/cpp.py
index a60ee505b6..a8b7e4ed48 100644
--- a/torchinductor/codegen/cpp.py
+++ b/torchinductor/codegen/cpp.py
@@ -186,6 +186,10 @@ def exp(x):
     def sqrt(x):
         return f"std::sqrt({x})"
 
+    @staticmethod
+    def rsqrt(x):
+        return f"1 / std::sqrt({x})"
+
     @staticmethod
     def pow(a, b):
         return f"std::pow({a}, {b})"
diff --git a/torchinductor/codegen/triton.py b/torchinductor/codegen/triton.py
index b339bd7a5e..ebcfca7ac6 100644
--- a/torchinductor/codegen/triton.py
+++ b/torchinductor/codegen/triton.py
@@ -159,6 +159,10 @@ def rand(seed, offset, _):  # _ here to keep the contract identical to CPU rand
     def randn(seed, offset, _):  # _ here to keep the contract identical to CPU randn op
         return f"tl.randn({seed}, {offset})"
 
+    @staticmethod
+    def rsqrt(x):
+        return f"tl.libdevice.rsqrt({x})"
+
     @staticmethod
     def pow(a, b):
         return f"tl.libdevice.pow({a}, {b})"
diff --git a/torchinductor/decomposition.py b/torchinductor/decomposition.py
index 3b5b3a8356..02ca70e596 100644
--- a/torchinductor/decomposition.py
+++ b/torchinductor/decomposition.py
@@ -82,6 +82,7 @@
     aten.tanh_backward,
     aten.threshold_backward,
     aten.transpose.int,
+    aten.tril.default,
     aten.upsample_nearest2d_backward,
     aten.upsample_bilinear2d.vec,
 ]
@@ -125,11 +126,6 @@ def tanh(x):
     return 2.0 / (1.0 + torch.exp(-2.0 * x)) - 1.0
 
 
-@register_decomposition([aten.rsqrt])
-def rsqrt(x):
-    return torch.reciprocal(torch.sqrt(x))
-
-
 @register_decomposition([aten.log2])
 def log2(x):
     return torch.log(x) * (1.0 / math.log(2.0))
diff --git a/torchinductor/lowering.py b/torchinductor/lowering.py
index ecc98570f0..71810324ce 100644
--- a/torchinductor/lowering.py
+++ b/torchinductor/lowering.py
@@ -1857,6 +1857,28 @@ def accumulate(out_x, out_y, index_range1, index_range2=None):
     )
 
 
+@register_lowering(prims.rev.default)
+def rev(x, dims):
+    # note - dims pre-canonicalized
+    x_loader = x.make_loader()
+    sizes = x.get_size()
+
+    def loader(idx):
+        idx = list(idx)
+        assert len(idx) == len(sizes)
+        for dim in dims:
+            idx[dim] = (sizes[dim] - 1) - idx[dim]
+
+        return x_loader(idx)
+
+    return Pointwise.create(
+        device=x.get_device(),
+        dtype=x.get_dtype(),
+        inner_fn=loader,
+        ranges=sizes,
+    )
+
+
 @register_lowering(aten.constant_pad_nd, type_promote=False)
 def constant_pad_nd(x, padding, fill_value=0):
     assert (len(padding) % 2) == 0
@@ -2653,6 +2675,7 @@ def sum_(x, axis=None, keepdims=False, *, dtype=None):
 register_pointwise(aten.reciprocal)
 register_pointwise(aten.remainder)
 register_pointwise(aten.round)
+register_pointwise(aten.rsqrt)
 register_pointwise(aten.sign)
 register_pointwise(aten.silu)
 register_pointwise(aten.ceil)
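
For context on the new test_flip case: torch.flip accepts negative dimensions, and they are canonicalized before the lowering runs, which is what the "dims pre-canonicalized" note in rev refers to. A quick eager-mode illustration, not part of the diff:

import torch

x = torch.arange(6).reshape(2, 3)

# Negative dims are canonicalized, so flipping dim -1 equals flipping dim 1.
assert torch.equal(torch.flip(x, (-1,)), torch.flip(x, (1,)))

# Flipping the last dim reverses each row: [0, 1, 2] -> [2, 1, 0].
assert torch.equal(torch.flip(x, (1,))[0], torch.tensor([2, 1, 0]))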
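
The rev lowering builds a Pointwise whose inner function rewrites output indices back into input indices: flipping dimension dim maps index i to size - 1 - i, so no data is copied at lowering time. A minimal standalone sketch of that index math (flip_indices is a hypothetical helper, not from the diff):

def flip_indices(idx, sizes, dims):
    """Mirror idx along each dimension in dims, as rev's loader does."""
    idx = list(idx)
    for dim in dims:
        idx[dim] = (sizes[dim] - 1) - idx[dim]
    return idx

# For a (2, 3)-shaped tensor, flipping dim 1 maps column j to column 2 - j.
assert flip_indices([0, 0], [2, 3], dims=(1,)) == [0, 2]
assert flip_indices([1, 2], [2, 3], dims=(1,)) == [1, 0]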
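
The diff also drops the rsqrt decomposition (reciprocal of sqrt) in favor of registering aten.rsqrt as a pointwise op, letting each backend emit a native form: 1 / std::sqrt on CPU and tl.libdevice.rsqrt on Triton. A quick eager-mode sanity check that the two formulations agree numerically, offered as an illustration rather than part of the change:

import torch

x = torch.rand(16) + 0.1  # keep inputs positive so sqrt is well defined

decomposed = torch.reciprocal(torch.sqrt(x))  # the removed decomposition
native = torch.rsqrt(x)                       # what the new lowering targets

torch.testing.assert_close(native, decomposed)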