BF16 support for Quant-LLM kernel #1147

Merged: 30 commits, Nov 4, 2024
Commits
f50d8d7
Add FP6 benchmark option to use BF16
tobiasvanderwerff Oct 10, 2024
a714377
Change dequant bit-shifting logic for BF16
tobiasvanderwerff Oct 11, 2024
5af3b7e
Modify dequant + tensor core ops for bf16
tobiasvanderwerff Oct 14, 2024
125f17c
Template progress
tobiasvanderwerff Oct 14, 2024
b3c3be0
Modify fpx quant logic to include bf16
tobiasvanderwerff Oct 18, 2024
f828763
Add tests for FP6 BF16
tobiasvanderwerff Oct 18, 2024
ff2c6e8
Use type punning for large exponent multiplication
tobiasvanderwerff Oct 22, 2024
4304dcc
Fix some TODOs
tobiasvanderwerff Oct 23, 2024
2d00a3a
Remove option to add exponent bias directly to the exponent bits
tobiasvanderwerff Oct 23, 2024
ceaed34
Reformat
tobiasvanderwerff Oct 23, 2024
b532c51
Cleanup
tobiasvanderwerff Oct 23, 2024
e89274b
Fix alignment
tobiasvanderwerff Oct 24, 2024
ac0fbe0
Remove templated input type whenever possible
tobiasvanderwerff Oct 24, 2024
c1dce42
Remove templated input type whenever possible 2
tobiasvanderwerff Oct 24, 2024
4546c8b
Remove templated input type whenever possible 3
tobiasvanderwerff Oct 24, 2024
bba42cf
Less hacky way to construct a float with a large exponent
tobiasvanderwerff Oct 24, 2024
e66395e
rtol=1e-2 instead of 1e-3 for bfloat16 test
tobiasvanderwerff Oct 24, 2024
7e9350e
Guards for SM75
tobiasvanderwerff Oct 24, 2024
401559f
Remove redundant `__CUDA_ARCH` guards in host code
tobiasvanderwerff Oct 29, 2024
5d52e5b
Fix consistency in checking for `CUDA_ARCH` versions
tobiasvanderwerff Oct 29, 2024
398da5b
Update docs
tobiasvanderwerff Oct 29, 2024
d38490f
Make float bias a constexpr
tobiasvanderwerff Oct 30, 2024
11ac84b
Update docs more
tobiasvanderwerff Oct 30, 2024
7bd2833
Fix SM75 support
tobiasvanderwerff Oct 30, 2024
69e901d
Compile guard for sm<75
tobiasvanderwerff Oct 31, 2024
8747d6d
Check for CUDA synchronous errors after kernel launch
tobiasvanderwerff Oct 31, 2024
59f5eb7
Updated compile guard
tobiasvanderwerff Oct 31, 2024
c96cf18
Fix problematic usage of `__CUDA_ARCH__`
tobiasvanderwerff Nov 1, 2024
379bd5e
Fix incorrect CUDA error handling
tobiasvanderwerff Nov 1, 2024
a6de35a
Make the kernel fail for sm75 + bfloat16 inputs
tobiasvanderwerff Nov 1, 2024
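Several of the commits above (2d00a3a, ff2c6e8, bba42cf, d38490f) revolve around constructing a float whose value is a large power of two by writing directly into the exponent bits. The kernel itself is CUDA and is not shown in this excerpt, so the following Python sketch is only an illustration of the general bit-level idea (my assumption about what those commit messages refer to), using the IEEE-754 float32 layout (8 exponent bits, bias 127):

```python
import struct

def pow2_float(k: int) -> float:
    """Build the float32 value 2**k by placing (k + 127) in the exponent field."""
    assert -126 <= k <= 127, "k must fit in a normal float32 exponent"
    bits = (k + 127) << 23  # sign = 0, mantissa = 0, exponent = k + bias
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(pow2_float(10))   # 1024.0
print(pow2_float(-10))  # 0.0009765625
```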
37 changes: 25 additions & 12 deletions benchmarks/benchmark_fp6.py
@@ -8,29 +8,42 @@


def benchmark(m: int, k: int, n: int):
float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayout(3, 2))
fp16_weight = fp6_weight.dequantize(torch.half)

fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
fp6_output = F.linear(fp16_act, fp6_weight)
float_data_fp16 = torch.randn(n, k, dtype=torch.float16, device="cuda")
float_data_bf16 = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
fp6_weight_fp16 = to_affine_quantized_fpx(float_data_fp16, FloatxTensorCoreLayout(3, 2))
fp6_weight_bf16 = to_affine_quantized_fpx(float_data_bf16, FloatxTensorCoreLayout(3, 2))
fp16_weight = fp6_weight_fp16.dequantize(torch.float16)
bf16_weight = fp6_weight_bf16.dequantize(torch.bfloat16)

fp16_act = torch.randn(m, k, dtype=torch.float16, device="cuda")
bf16_act = fp16_act.to(torch.bfloat16)
fp6_output_fp16 = F.linear(fp16_act, fp6_weight_fp16)
fp6_output_bf16 = F.linear(bf16_act, fp6_weight_bf16)
fp16_output = F.linear(fp16_act, fp16_weight)
bf16_output = F.linear(bf16_act, bf16_weight)

fp6_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight)
fp16_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight)
bf16_time = benchmark_torch_function_in_microseconds(F.linear, bf16_act, bf16_weight)
fp6_time_fp16 = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight_fp16)
fp6_time_bf16 = benchmark_torch_function_in_microseconds(F.linear, bf16_act, fp6_weight_bf16)

# follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py
# doesn't seem to be the right way to check for correctness
correct = (fp6_output - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3
correct_fp16 = (fp6_output_fp16 - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3
correct_bf16 = (fp6_output_bf16 - bf16_output).abs().mean() / bf16_output.abs().mean() < 1e-2

Collaborator:
Just curious: I've seen that when BF16 is used, the tolerance generally has to be quite a bit higher than for FP16. From your experience working on this, do you suspect any part of the code might cause this loss of precision? E.g. perhaps some parts are computed in BF16 instead of FP32. Or maybe it's just the way it is.

Contributor Author:
All I know is that BF16 has fewer fraction (mantissa) bits than FP16 (7 bits vs. 10 bits), which leads to lower precision for BF16 compared to FP16. I can't think of any part of the FP6 kernel that would inherently lead to more loss of precision for BF16.
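As a side note (not part of the original thread), a minimal sketch of that precision gap: the unit round-off of each dtype can be read from torch.finfo, and BF16's epsilon is 8x larger than FP16's, which is consistent with the looser 1e-2 tolerance used above.

```python
import torch

# BF16 keeps 7 mantissa bits vs. 10 for FP16, so its machine epsilon is 8x larger.
print(torch.finfo(torch.float16).eps)   # 0.0009765625  (2**-10)
print(torch.finfo(torch.bfloat16).eps)  # 0.0078125     (2**-7)
```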


return {
"m": m,
"k": k,
"n": n,
"fp6_latency (ms)": fp6_time,
"fp16_latency (ms)": fp16_time,
"speedup (d/s)": fp16_time / fp6_time,
"correct": correct,
"fp6-fp16 latency (ms)": fp6_time_fp16,
"fp16 latency (ms)": fp16_time,
"speedup fp16": fp16_time / fp6_time_fp16,
"correct fp16": correct_fp16,
"fp6-bf16 latency (ms)": fp6_time_bf16,
"bf16 latency (ms)": bf16_time,
"speedup bf16": bf16_time / fp6_time_bf16,
"correct bf16": correct_bf16,
}
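The inline comment in the hunk above notes that the mean-relative-error check may not be the right way to verify correctness. A hypothetical alternative (not part of this PR; the tolerances below are my own guesses, mirroring the 1e-3 / 1e-2 thresholds used elsewhere in the diff) would be an element-wise comparison via torch.testing.assert_close:

```python
import torch

def outputs_match(actual: torch.Tensor, expected: torch.Tensor, dtype: torch.dtype) -> bool:
    # Element-wise check: |actual - expected| <= atol + rtol * |expected|
    rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3
    try:
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=1e-3)
        return True
    except AssertionError:
        return False
```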


7 changes: 4 additions & 3 deletions test/dtypes/test_floatx.py
@@ -91,16 +91,17 @@ def test_to_copy_device(self, ebits, mbits):
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+")
@parametrize("ebits,mbits", _Floatx_DTYPES)
@parametrize("bias", [False, True])
@parametrize("dtype", [torch.half, torch.bfloat16])
@pytest.mark.skipif(is_fbcode(), reason="broken in fbcode")
def test_fpx_weight_only(self, ebits, mbits, bias):
def test_fpx_weight_only(self, ebits, mbits, bias, dtype):
N, OC, IC = 4, 256, 64
device = "cuda"

linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=torch.half)
linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=dtype)
fpx_linear = copy.deepcopy(linear)
quantize_(fpx_linear, fpx_weight_only(ebits, mbits))

x = torch.randn(N, IC, device=device, dtype=torch.half)
x = torch.randn(N, IC, device=device, dtype=dtype)
expected = fpx_linear(x)
actual = torch.compile(fpx_linear, fullgraph=True)(x)
# somehow compile now changes the result a bit
21 changes: 12 additions & 9 deletions test/test_ops.py
@@ -33,22 +33,23 @@


class TestOps(TestCase):
def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device, dtype):
# Randomly initialize each byte
nbits = 1 + ebits + mbits
floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8)
scale = torch.rand(OC).half() + 0.5
fp16_act = torch.rand(BS, IC).half() + 0.5
scale = torch.rand(OC).to(dtype) + 0.5
fp16_act = torch.rand(BS, IC).to(dtype) + 0.5
return floatx_weight.to(device), scale.to(device), fp16_act.to(device)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
def test_quant_llm_linear(self, ebits, mbits):
@parametrize("dtype", [torch.half, torch.bfloat16])
def test_quant_llm_linear(self, ebits, mbits, dtype):
BS = 2
OC = 256
IC = 256
splitK = 1
floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda")
floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda", dtype)

# smoke test
torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK)
@@ -60,19 +61,21 @@ def test_quant_llm_linear(self, ebits, mbits):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
@parametrize("dtype", [torch.half, torch.bfloat16])
def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK, dtype):
# adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py
floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda")
floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda", dtype)

results_floatx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK)

fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).half()
fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).to(dtype)
results_fp16 = fp16_act @ fp16_weight.T

error = (results_floatx - results_fp16).abs().mean()
gt = results_fp16.abs().mean()
relative_error = error / gt
assert relative_error < 1e-3
rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3
assert relative_error < rtol

instantiate_parametrized_tests(TestOps)

4 changes: 2 additions & 2 deletions torchao/csrc/cuda/fp6_llm/README.md
@@ -1,7 +1,7 @@
# FP6-LLM kernel

This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 and W is in FP6 (E3M2 without infinities and NaN).
This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 or BF16 and W is in FP6 (E3M2 without infinities and NaN).

On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion.

See https://github.com/pytorch/ao/pull/223 for some benchmark results.
See https://github.com/pytorch/ao/pull/223 and https://github.com/pytorch/ao/pull/1147 for some benchmark results.
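For reference, a minimal usage sketch of the kernel entry point, adapted from the test_quant_llm_linear test in this PR (shapes, splitK, and the random inputs are illustrative only):

```python
import torch
import torchao

ebits, mbits = 3, 2              # FP6 = E3M2
BS, OC, IC, splitK = 2, 256, 256, 1
nbits = 1 + ebits + mbits

# Packed FP6 weight stored as uint8, a per-channel scale, and a BF16 activation.
# Per commit a6de35a, BF16 inputs are rejected on sm75; FP16 inputs also work there.
floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8, device="cuda")
scale = (torch.rand(OC, device="cuda") + 0.5).to(torch.bfloat16)
act = (torch.rand(BS, IC, device="cuda") + 0.5).to(torch.bfloat16)

out = torchao.ops.quant_llm_linear(ebits, mbits, act, floatx_weight, scale, splitK)  # (BS, OC)
```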