
Generic packing algorithms from size N to M #284

Open
@vayuda

Description


To support sub-byte dtypes for quantization, I (and many others) believe it is better to pack these smaller dtypes into existing PyTorch dtypes, trading a small amount of extra computation for reduced memory bandwidth contention. Here is a preliminary PyTorch algorithm that does this. It supports many types of conversions, as seen in the tests.
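To make the scheme concrete, here is a minimal sketch of the idea for the 2-bit-into-uint8 case. This is not the exact code from my branch: pack_uint2 and unpack_uint2 are placeholder names, and the last dimension is assumed to be divisible by 4.

import torch

def pack_uint2(x: torch.Tensor) -> torch.Tensor:
    # Pack four 2-bit values (each in [0, 3]) into one uint8,
    # highest field first, matching the 6/4/2/0 shifts below.
    x = x.to(torch.uint8)
    return (x[..., 0::4] << 6) | (x[..., 1::4] << 4) | (x[..., 2::4] << 2) | x[..., 3::4]

def unpack_uint2(packed: torch.Tensor) -> torch.Tensor:
    # Inverse: shift each 2-bit field down and mask with 0b11,
    # expanding the last dimension by a factor of 4.
    shifts = torch.tensor([6, 4, 2, 0], dtype=torch.uint8, device=packed.device)
    out = (packed.unsqueeze(-1) >> shifts) & 0x3
    return out.flatten(-2)

Round-tripping torch.randint(0, 4, (n,), dtype=torch.uint8) through pack then unpack returns the original tensor; generalizing the shift/mask tables gives the other N-to-M conversions.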

Inspecting the compiled Triton code is promising: it launches only one kernel and allocates only one buffer. Here is a snippet:

@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 4  # which of the four 2-bit fields this lane extracts
    x1 = (xindex // 4)  # index of the source packed byte
    x2 = xindex  # linear output index
    tmp0 = x0
    tmp1 = tl.full([1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1], 1, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (x1), tmp4 & xmask, eviction_policy='evict_last', other=0.0)
    tmp6 = tl.full([1], 6, tl.uint8)
    tmp7 = tmp5 >> tmp6
    tmp8 = tl.full([1], 3, tl.uint8)
    tmp9 = tmp7 & tmp8
    tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
    tmp11 = tl.where(tmp4, tmp9, tmp10)
    tmp12 = tmp0 >= tmp3
    tmp13 = tl.full([1], 2, tl.int64)
    tmp14 = tmp0 < tmp13
    tmp15 = tmp12 & tmp14
    tmp16 = tl.load(in_ptr0 + (x1), tmp15 & xmask, eviction_policy='evict_last', other=0.0)
    tmp17 = tl.full([1], 4, tl.uint8)
    tmp18 = tmp16 >> tmp17
    tmp19 = tmp18 & tmp8
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp15, tmp19, tmp20)
    tmp22 = tmp0 >= tmp13
    tmp23 = tl.full([1], 3, tl.int64)
    tmp24 = tmp0 < tmp23
    tmp25 = tmp22 & tmp24
    tmp26 = tl.load(in_ptr0 + (x1), tmp25 & xmask, eviction_policy='evict_last', other=0.0)
    tmp27 = tl.full([1], 2, tl.uint8)
    tmp28 = tmp26 >> tmp27
    tmp29 = tmp28 & tmp8
    tmp30 = tl.full(tmp29.shape, 0.0, tmp29.dtype)
    tmp31 = tl.where(tmp25, tmp29, tmp30)
    tmp32 = tmp0 >= tmp23
    tmp33 = tl.full([1], 4, tl.int64)
    tmp34 = tmp0 < tmp33
    tmp35 = tl.load(in_ptr0 + (x1), tmp32 & xmask, eviction_policy='evict_last', other=0.0)
    tmp36 = tl.full([1], 0, tl.uint8)
    tmp37 = tmp35 >> tmp36
    tmp38 = tmp37 & tmp8
    tmp39 = tl.full(tmp38.shape, 0.0, tmp38.dtype)
    tmp40 = tl.where(tmp32, tmp38, tmp39)
    tmp41 = tl.where(tmp25, tmp31, tmp40)
    tmp42 = tl.where(tmp15, tmp21, tmp41)
    tmp43 = tl.where(tmp4, tmp11, tmp42)
    tl.store(out_ptr0 + (x2), tmp43, xmask)

def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1 = args
    args.clear()
    s0 = arg0_1
    s1 = arg1_1
    s2 = arg2_1
    assert_size_stride(arg3_1, (s0, s1, s2), (s1*s2, s2, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((s0, s1, s2, 4), (4*s1*s2, 4*s2, 4, 1), torch.uint8)
        # Source Nodes: [stack], Original ATen: [aten.stack]
        triton_poi_fused_stack_0_xnumel = 4*s0*s1*s2
        stream0 = get_raw_stream(0)
        triton_poi_fused_stack_0.run(arg3_1, buf0, triton_poi_fused_stack_0_xnumel, grid=grid(triton_poi_fused_stack_0_xnumel), stream=stream0)
        del arg3_1
    return (reinterpret_tensor(buf0, (s0, s1, 4*s2), (4*s1*s2, 4*s2, 1), 0), )
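For what it's worth, output like the above can be reproduced by compiling an unpack function with torch.compile and dumping the generated code (e.g. by running with TORCH_LOGS=output_code). A hypothetical repro against the unpack_uint2 sketch above:

# Hypothetical repro; assumes the unpack_uint2 sketch from earlier.
compiled_unpack = torch.compile(unpack_uint2, dynamic=True)
packed = torch.randint(0, 256, (2, 3, 8), dtype=torch.uint8, device="cuda")
out = compiled_unpack(packed)  # shape (2, 3, 32); a single fused kernel

dynamic=True mirrors the symbolic sizes (s0, s1, s2) in the generated call() above, and the final reinterpret_tensor corresponds to the flatten(-2) in the sketch.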
