This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Conversation

desertfire (Contributor)

Summary: for #918

desertfire commented Aug 22, 2022

BEFORE: tmp4 = tl.load(in_ptr1 + 0, tmp2) is incorrectly promoted out of the reduction loop, referencing the mask tmp2 before the loop body defines it.

@reduction_heuristics(size_hints=[32, 8192])
@triton.jit
def kernel0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    x0 = xindex
    tmp4 = tl.load(in_ptr1 + 0, tmp2)
    _tmp12 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = r1 + (7823*x0)
        tmp1 = 140800
        tmp2 = tmp0 < tmp1
        tmp3 = tl.load(in_ptr0 + (r1 + (7823*x0)) % 140800 + tl.zeros([XBLOCK, RBLOCK], tl.int32), xmask & rmask & tmp2, eviction_policy='evict_last')
        tmp5 = tmp3 / tmp4
        tmp6 = tl.load(in_ptr2 + (r1 + (7823*x0)) % 140800 + tl.zeros([XBLOCK, RBLOCK], tl.int32), xmask & rmask & tmp2, eviction_policy='evict_last')
        tmp7 = tmp5 * tmp6
        tmp8 = 0
        tmp9 = tmp2 | tl.zeros(tmp7.shape, tmp2.dtype) if tmp7.numel > 1 else tmp2
        tmp10 = tmp9 | tl.zeros(tmp8.shape, tmp9.dtype) if tmp8.numel > 1 else tmp9
        tmp11 = tl.where(tmp10, tmp7, tmp8)
        _tmp12 = tl.where(xmask & rmask, _tmp12 + tmp11, _tmp12)
    tmp12 = tl.reshape(tl.sum(_tmp12, 1), [XBLOCK, 1])
    tl.store(out_ptr0 + x0, tmp12, xmask)

AFTER:

@reduction_heuristics(size_hints=[32, 8192])
@triton.jit
def kernel0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    x0 = xindex
    _tmp12 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = r1 + (7823*x0)
        tmp1 = 140800
        tmp2 = tmp0 < tmp1
        tmp3 = tl.load(in_ptr0 + (r1 + (7823*x0)) % 140800 + tl.zeros([XBLOCK, RBLOCK], tl.int32), xmask & rmask & tmp2, eviction_policy='evict_last')
        tmp4 = tl.load(in_ptr1 + 0, tmp2)
        tmp5 = tmp3 / tmp4
        tmp6 = tl.load(in_ptr2 + (r1 + (7823*x0)) % 140800 + tl.zeros([XBLOCK, RBLOCK], tl.int32), xmask & rmask & tmp2, eviction_policy='evict_last')
        tmp7 = tmp5 * tmp6
        tmp8 = 0
        tmp9 = tmp2 | tl.zeros(tmp7.shape, tmp2.dtype) if tmp7.numel > 1 else tmp2
        tmp10 = tmp9 | tl.zeros(tmp8.shape, tmp9.dtype) if tmp8.numel > 1 else tmp9
        tmp11 = tl.where(tmp10, tmp7, tmp8)
        _tmp12 = tl.where(xmask & rmask, _tmp12 + tmp11, _tmp12)
    tmp12 = tl.reshape(tl.sum(_tmp12, 1), [XBLOCK, 1])
    tl.store(out_ptr0 + x0, tmp12, xmask)
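The hoisting bug is easy to mimic in plain Python (a hypothetical sketch: `load` stands in for `tl.load`, and the `tmp` names echo the kernels above). Promoting the masked load above the loop makes it reference its mask before the loop body assigns it:

```python
def load(value, mask):
    """Stand-in for tl.load(ptr, mask): value where mask is truthy, else 0."""
    return value if mask else 0

def hoisted():
    # BUG (mirrors the BEFORE kernel): the masked load was promoted out of
    # the loop, but its mask tmp2 is only assigned inside the loop body, so
    # this line raises UnboundLocalError (a subclass of NameError).
    tmp4 = load(10, tmp2)
    acc = tmp4
    for r in range(4):
        tmp2 = r < 3
        acc += load(10, tmp2)
    return acc

def fixed():
    # FIX (mirrors the AFTER kernel): perform the load only after its mask
    # has been computed inside the loop.
    acc = 0
    for r in range(4):
        tmp2 = r < 3
        acc += load(10, tmp2)
    return acc
```

In the generated kernel, keeping the `tl.load` inside the reduction loop, as in the AFTER version, restores the dependency order between the load and its mask.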

desertfire:

There is still another error where tl.load fails for a dim=0 tensor:

tmp4 =  tl.load(in_ptr1 + 0, tmp2)
...
arg1_1 = rand_strided((), (), device='cuda', dtype=torch.float32)

desertfire force-pushed the binbao/torchinductor_3 branch from b089fb0 to 777059f on August 23, 2022 03:48
@@ -632,6 +638,7 @@ def indexing(self, index: sympy.Expr, copy_shape=None, dense_indexing=False):
or indirect_indexing
or self._load_mask is not None
) and index != 0
A reviewer commented:

The index != 0 here seems related, I think this was related to loading the random mask.

@@ -646,7 +653,7 @@ def indexing(self, index: sympy.Expr, copy_shape=None, dense_indexing=False):
mask.append(f"{tree.prefix}mask")
dense_mask.append(f"{tree.prefix}mask")

if need_dense and not have_dense:
if (need_dense and not have_dense) or load_index0:
A reviewer commented:

Why not

Suggested change
if (need_dense and not have_dense) or load_index0:
if (need_dense and not have_dense) or index == 0:

?

desertfire replied:

I wasn't sure whether this is safe to do when we are not calling from load, but the CI run suggests it is ok.

desertfire force-pushed the binbao/torchinductor_3 branch from 777059f to de3925a on August 23, 2022 12:51
desertfire marked this pull request as ready for review on August 23, 2022 13:23
desertfire merged commit 6c37dfe into main on August 23, 2022
ngimel commented Aug 23, 2022

> There is still another error where tl.load fails for a dim=0 tensor,

wdym? How does it fail?

desertfire replied:

> There is still another error where tl.load fails for a dim=0 tensor,
>
> wdym? How does it fail?

The error I saw was,

>       name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, extern_libs)
E       MemoryError: std::bad_alloc

../triton/python/triton/code_gen.py:1320: MemoryError

Adding tl.zeros([XBLOCK, RBLOCK]) to the indexing would make it pass, but as @ngimel pointed out in #981, what we really need here is a mask of None.
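The trade-off between the two options can be illustrated with a plain-Python stand-in for tl.load (a hypothetical sketch of the masking semantics, not real Triton): broadcasting the scalar index 0 to the block shape forces a block-shaped mask along with it, while a true scalar load can simply pass mask=None.

```python
def load(ptr, offset, mask=None):
    """Stand-in for tl.load of one element: ptr[offset], zeroed if masked off."""
    if mask is None:
        return ptr[offset]          # scalar load, no mask needed
    return ptr[offset] if mask else 0.0

ptr = [3.0]
XBLOCK, RBLOCK = 2, 4

# Workaround: broadcast index 0 to the block shape [XBLOCK, RBLOCK]
# (mimicking `0 + tl.zeros([XBLOCK, RBLOCK], tl.int32)`), which then
# requires a block-shaped mask as well.
block = [[load(ptr, 0, mask=True) for _ in range(RBLOCK)]
         for _ in range(XBLOCK)]

# Preferred fix (per #981): a plain scalar load with mask=None.
scalar = load(ptr, 0)
```

The broadcast version materializes an XBLOCK-by-RBLOCK block of identical values just to load one scalar, which is why emitting the load with no mask is the cleaner fix.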

desertfire added a commit that referenced this pull request Aug 24, 2022
Summary: Change how tl.load with index 0 is done