
Commit ea455b7

add a few lowerings (#968)
* add a few lowerings
* off by one

1 parent 5c18f71

File tree

5 files changed (+41, -5 lines)


test/test_torchinductor.py

Lines changed: 9 additions & 0 deletions
@@ -1667,6 +1667,15 @@ def fn(x):
             (torch.randn([64]),),
         )
 
+    def test_flip(self):
+        def fn(x):
+            return torch.flip(x, (-1,)), torch.flip(x, (0, 2)) - 2
+
+        self.common(
+            fn,
+            (torch.randn([1, 2, 6, 6]),),
+        )
+
     def test_log2(self):
         def fn(x):
             return torch.log2(x), torch.log2(x + 1) - 2
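
For reference, a minimal eager-mode sketch of what test_flip exercises, using the same shape and dims as the test; the assertions are illustrative, not part of the suite:

import torch

x = torch.randn(1, 2, 6, 6)
a = torch.flip(x, (-1,))       # reverse the last dimension
b = torch.flip(x, (0, 2)) - 2  # reverse dims 0 and 2, then subtract 2

# flip preserves shape, and the first slice along a flipped dim
# equals the last slice of the input along that dim
assert a.shape == x.shape and b.shape == x.shape
assert torch.equal(a[..., 0], x[..., -1])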

torchinductor/codegen/cpp.py

Lines changed: 4 additions & 0 deletions
@@ -186,6 +186,10 @@ def exp(x):
     def sqrt(x):
         return f"std::sqrt({x})"
 
+    @staticmethod
+    def rsqrt(x):
+        return f"1 / std::sqrt({x})"
+
     @staticmethod
     def pow(a, b):
         return f"std::pow({a}, {b})"

torchinductor/codegen/triton.py

Lines changed: 4 additions & 0 deletions
@@ -159,6 +159,10 @@ def rand(seed, offset, _):  # _ here to keep the contract identical to CPU rand
     def randn(seed, offset, _):  # _ here to keep the contract identical to CPU randn op
         return f"tl.randn({seed}, {offset})"
 
+    @staticmethod
+    def rsqrt(x):
+        return f"tl.libdevice.rsqrt({x})"
+
     @staticmethod
     def pow(a, b):
         return f"tl.libdevice.pow({a}, {b})"

torchinductor/decomposition.py

Lines changed: 1 addition & 5 deletions
@@ -82,6 +82,7 @@
     aten.tanh_backward,
     aten.threshold_backward,
     aten.transpose.int,
+    aten.tril.default,
     aten.upsample_nearest2d_backward,
     aten.upsample_bilinear2d.vec,
 ]
@@ -125,11 +126,6 @@ def tanh(x):
     return 2.0 / (1.0 + torch.exp(-2.0 * x)) - 1.0
 
 
-@register_decomposition([aten.rsqrt])
-def rsqrt(x):
-    return torch.reciprocal(torch.sqrt(x))
-
-
 @register_decomposition([aten.log2])
 def log2(x):
     return torch.log(x) * (1.0 / math.log(2.0))
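
The removed decomposition rewrote rsqrt as reciprocal(sqrt(x)); with the backend ops added above, the operator is instead lowered directly (see register_pointwise(aten.rsqrt) in torchinductor/lowering.py below). A quick numerical check of the equivalence the old decomposition relied on:

import torch

x = torch.rand(16) + 0.1  # positive inputs keep sqrt well-defined
torch.testing.assert_close(torch.reciprocal(torch.sqrt(x)), torch.rsqrt(x))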

torchinductor/lowering.py

Lines changed: 23 additions & 0 deletions
@@ -1858,6 +1858,28 @@ def accumulate(out_x, out_y, index_range1, index_range2=None):
     )
 
 
+@register_lowering(prims.rev.default)
+def rev(x, dims):
+    # note - dims pre-canonicalized
+    x_loader = x.make_loader()
+    sizes = x.get_size()
+
+    def loader(idx):
+        idx = list(idx)
+        assert len(idx) == len(sizes)
+        for dim in dims:
+            idx[dim] = (sizes[dim] - 1) - idx[dim]
+
+        return x_loader(idx)
+
+    return Pointwise.create(
+        device=x.get_device(),
+        dtype=x.get_dtype(),
+        inner_fn=loader,
+        ranges=sizes,
+    )
+
+
 @register_lowering(aten.constant_pad_nd, type_promote=False)
 def constant_pad_nd(x, padding, fill_value=0):
     assert (len(padding) % 2) == 0
@@ -2748,6 +2770,7 @@ def sum_(x, axis=None, keepdims=False, *, dtype=None):
 register_pointwise(aten.reciprocal)
 register_pointwise(aten.remainder)
 register_pointwise(aten.round)
+register_pointwise(aten.rsqrt)
 register_pointwise(aten.sign)
 register_pointwise(aten.silu)
 register_pointwise(aten.ceil)
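
A pure-Python reference for the index math in the rev lowering above, assuming dims are already canonicalized to non-negative values: each output index reads the input with idx[dim] mirrored to (sizes[dim] - 1) - idx[dim]. The sketch below (rev_reference is a hypothetical helper, not torchinductor code) checks it against torch.flip, which the new test exercises:

import itertools
import torch

def rev_reference(x, dims):
    sizes = x.shape
    out = torch.empty_like(x)
    for idx in itertools.product(*(range(s) for s in sizes)):
        src = list(idx)
        for dim in dims:
            src[dim] = (sizes[dim] - 1) - src[dim]  # mirror the index
        out[idx] = x[tuple(src)]
    return out

x = torch.randn(1, 2, 6, 6)
torch.testing.assert_close(rev_reference(x, (0, 2)), torch.flip(x, (0, 2)))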
