To support sub-byte dtypes for quantization, I (and many others) believe it is better to pack these smaller dtypes into existing PyTorch dtypes, trading a small amount of extra computation for reduced memory bandwidth pressure. Here is a preliminary algorithm in PyTorch for doing this; it supports many kinds of conversions, as shown in the tests.
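For concreteness, here is a minimal sketch of the uint2-in-uint8 case, which is the case the compiled trace below corresponds to. The function names (unpack_uint2, pack_uint2) and shift constants are illustrative, not the final API:

import torch

def unpack_uint2(packed: torch.Tensor) -> torch.Tensor:
    # Each uint8 holds four 2-bit values. Shift each field down, mask off
    # the low two bits, and interleave the four lanes along the last dim.
    lanes = [(packed >> s) & 0b11 for s in (6, 4, 2, 0)]
    return torch.stack(lanes, dim=-1).flatten(start_dim=-2)

def pack_uint2(unpacked: torch.Tensor) -> torch.Tensor:
    # Inverse: group consecutive runs of four 2-bit values and OR them
    # back into a single uint8.
    grouped = unpacked.reshape(*unpacked.shape[:-1], -1, 4)
    packed = torch.zeros(grouped.shape[:-1], dtype=torch.uint8, device=unpacked.device)
    for i, s in enumerate((6, 4, 2, 0)):
        packed |= (grouped[..., i] & 0b11) << s
    return packed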
Inspecting the compiled Triton code looks promising: it launches only one kernel and allocates only one buffer. In the kernel, each output lane x0 in {0, 1, 2, 3} selects the corresponding shift (6, 4, 2, or 0) and masks with 3, which is exactly the unpack logic above. Here is a snippet:
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 4
    x1 = (xindex // 4)
    x2 = xindex
    tmp0 = x0
    tmp1 = tl.full([1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1], 1, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (x1), tmp4 & xmask, eviction_policy='evict_last', other=0.0)
    tmp6 = tl.full([1], 6, tl.uint8)
    tmp7 = tmp5 >> tmp6
    tmp8 = tl.full([1], 3, tl.uint8)
    tmp9 = tmp7 & tmp8
    tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
    tmp11 = tl.where(tmp4, tmp9, tmp10)
    tmp12 = tmp0 >= tmp3
    tmp13 = tl.full([1], 2, tl.int64)
    tmp14 = tmp0 < tmp13
    tmp15 = tmp12 & tmp14
    tmp16 = tl.load(in_ptr0 + (x1), tmp15 & xmask, eviction_policy='evict_last', other=0.0)
    tmp17 = tl.full([1], 4, tl.uint8)
    tmp18 = tmp16 >> tmp17
    tmp19 = tmp18 & tmp8
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp15, tmp19, tmp20)
    tmp22 = tmp0 >= tmp13
    tmp23 = tl.full([1], 3, tl.int64)
    tmp24 = tmp0 < tmp23
    tmp25 = tmp22 & tmp24
    tmp26 = tl.load(in_ptr0 + (x1), tmp25 & xmask, eviction_policy='evict_last', other=0.0)
    tmp27 = tl.full([1], 2, tl.uint8)
    tmp28 = tmp26 >> tmp27
    tmp29 = tmp28 & tmp8
    tmp30 = tl.full(tmp29.shape, 0.0, tmp29.dtype)
    tmp31 = tl.where(tmp25, tmp29, tmp30)
    tmp32 = tmp0 >= tmp23
    tmp33 = tl.full([1], 4, tl.int64)
    tmp34 = tmp0 < tmp33
    tmp35 = tl.load(in_ptr0 + (x1), tmp32 & xmask, eviction_policy='evict_last', other=0.0)
    tmp36 = tl.full([1], 0, tl.uint8)
    tmp37 = tmp35 >> tmp36
    tmp38 = tmp37 & tmp8
    tmp39 = tl.full(tmp38.shape, 0.0, tmp38.dtype)
    tmp40 = tl.where(tmp32, tmp38, tmp39)
    tmp41 = tl.where(tmp25, tmp31, tmp40)
    tmp42 = tl.where(tmp15, tmp21, tmp41)
    tl.store(out_ptr0 + (x2), tmp43, xmask)
''', device_str='cuda')
def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1 = args
    args.clear()
    s0 = arg0_1
    s1 = arg1_1
    s2 = arg2_1
    assert_size_stride(arg3_1, (s0, s1, s2), (s1*s2, s2, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((s0, s1, s2, 4), (4*s1*s2, 4*s2, 4, 1), torch.uint8)
        # Source Nodes: [stack], Original ATen: [aten.stack]
        triton_poi_fused_stack_0_xnumel = 4*s0*s1*s2
        stream0 = get_raw_stream(0)
        triton_poi_fused_stack_0.run(arg3_1, buf0, triton_poi_fused_stack_0_xnumel, grid=grid(triton_poi_fused_stack_0_xnumel), stream=stream0)
        del arg3_1
    return (reinterpret_tensor(buf0, (s0, s1, 4*s2), (4*s1*s2, 4*s2, 1), 0), )
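For reference, a trace like the one above can be reproduced by compiling the unpack sketch with dynamic shapes and dumping the generated code via TORCH_LOGS="output_code" (the shapes here are arbitrary examples):

compiled = torch.compile(unpack_uint2, dynamic=True)
packed = torch.randint(0, 256, (2, 3, 8), dtype=torch.uint8, device="cuda")
out = compiled(packed)  # (2, 3, 32): each packed byte expands to four uint2 values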