|
| 1 | +import copy |
| 2 | +from functools import partial |
| 3 | + |
| 4 | +import pytest |
| 5 | +import torch |
| 6 | +from torch import nn |
| 7 | +from torch.testing._internal.common_utils import ( |
| 8 | + TestCase, |
| 9 | + instantiate_parametrized_tests, |
| 10 | + parametrize, |
| 11 | + run_tests, |
| 12 | +) |
| 13 | +from torchao.prototype import low_bit_optim |
| 14 | +from torchao.prototype.low_bit_optim import subclass_8bit, subclass_4bit |
| 15 | +from torchao.utils import TORCH_VERSION_AFTER_2_3 |
| 16 | + |
| 17 | +try: |
| 18 | + import bitsandbytes as bnb |
| 19 | +except ImportError: |
| 20 | + bnb = None |
| 21 | + |
| 22 | +try: |
| 23 | + import lpmm |
| 24 | +except ImportError: |
| 25 | + lpmm = None |
| 26 | + |
| 27 | + |
| 28 | +_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) |
| 29 | + |
| 30 | + |
| 31 | +class TestQuantize(TestCase): |
| 32 | + @parametrize("device", _DEVICES) |
| 33 | + def test_quantize_8bit_with_qmap_correctness(self, device): |
| 34 | + x = torch.randn(32, 1024, device=device) |
| 35 | + qmap = torch.tensor(subclass_8bit.QMAP_SIGNED, device=device) |
| 36 | + |
| 37 | + actual_codes, actual_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256, implementation=1) |
| 38 | + expected_codes, expected_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256, implementation=0) |
| 39 | + |
| 40 | + torch.testing.assert_close(actual_codes, expected_codes) |
| 41 | + torch.testing.assert_close(actual_scale, expected_scale) |
| 42 | + |
| 43 | + @parametrize("device", _DEVICES) |
| 44 | + def test_quantize_8bit_with_qmap_compile(self, device): |
| 45 | + x = torch.randn(32, 1024, device=device) |
| 46 | + qmap = torch.tensor(subclass_8bit.QMAP_SIGNED, device=device) |
| 47 | + |
| 48 | + compiled_f = torch.compile(subclass_8bit.quantize_8bit_with_qmap, fullgraph=True) |
| 49 | + actual_codes, actual_scale = compiled_f(x, qmap, 256) |
| 50 | + expected_codes, expected_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256) |
| 51 | + |
| 52 | + torch.testing.assert_close(actual_codes, expected_codes) |
| 53 | + torch.testing.assert_close(actual_scale, expected_scale) |
| 54 | + |
| 55 | + @parametrize("device", _DEVICES) |
| 56 | + def test_quantize_4bit_with_qmap_correctness(self, device): |
| 57 | + x = torch.randn(32, 1024, device=device) |
| 58 | + qmap = torch.tensor(subclass_4bit.QMAP_SIGNED, device=device) |
| 59 | + |
| 60 | + actual_codes, actual_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256, implementation=1) |
| 61 | + expected_codes, expected_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256, implementation=0) |
| 62 | + |
| 63 | + torch.testing.assert_close(actual_codes, expected_codes) |
| 64 | + torch.testing.assert_close(actual_scale, expected_scale) |
| 65 | + |
| 66 | + @parametrize("device", _DEVICES) |
| 67 | + def test_quantize_4bit_with_qmap_compile(self, device): |
| 68 | + x = torch.randn(32, 1024, device=device) |
| 69 | + qmap = torch.tensor(subclass_4bit.QMAP_SIGNED, device=device) |
| 70 | + |
| 71 | + compiled_f = torch.compile(subclass_4bit.quantize_4bit_with_qmap, fullgraph=True) |
| 72 | + actual_codes, actual_scale = compiled_f(x, qmap, 256) |
| 73 | + expected_codes, expected_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256) |
| 74 | + |
| 75 | + torch.testing.assert_close(actual_codes, expected_codes) |
| 76 | + torch.testing.assert_close(actual_scale, expected_scale) |
| 77 | + |
| 78 | + |
| 79 | +class TestOptim(TestCase): |
| 80 | + @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle") |
| 81 | + @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA") |
| 82 | + @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3") |
| 83 | + @parametrize("optim_name", ["Adam8bit", "AdamW8bit"]) |
| 84 | + def test_optim_8bit_correctness(self, optim_name): |
| 85 | + device = "cuda" |
| 86 | + model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) |
| 87 | + model2 = copy.deepcopy(model1) |
| 88 | + |
| 89 | + optim1 = getattr(bnb.optim, optim_name)(model1.parameters()) |
| 90 | + optim2 = getattr(low_bit_optim, optim_name)(model2.parameters()) |
| 91 | + |
| 92 | + for _ in range(2): |
| 93 | + x = torch.randn(4, 32, device=device) |
| 94 | + |
| 95 | + loss1 = model1(x).sum() |
| 96 | + loss1.backward() |
| 97 | + optim1.step() |
| 98 | + optim1.zero_grad() |
| 99 | + |
| 100 | + loss2 = model2(x).sum() |
| 101 | + loss2.backward() |
| 102 | + optim2.step() |
| 103 | + optim2.zero_grad() |
| 104 | + |
| 105 | + for p1, p2 in zip(model1.parameters(), model2.parameters()): |
| 106 | + torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5) |
| 107 | + |
| 108 | + @pytest.mark.skipif(lpmm is None, reason="lpmm is not availablle") |
| 109 | + @pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA") |
| 110 | + @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3") |
| 111 | + @parametrize("optim_name", ["Adam4bit", "AdamW4bit"]) |
| 112 | + def test_optim_4bit_correctness(self, optim_name): |
| 113 | + device = "cuda" |
| 114 | + model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) |
| 115 | + model2 = copy.deepcopy(model1) |
| 116 | + |
| 117 | + # lpmm doesn't have Adam. use AdamW with no weight decay instead. |
| 118 | + if optim_name == "Adam4bit": |
| 119 | + optim1 = lpmm.optim.AdamW(model1.parameters(), weight_decay=0) |
| 120 | + elif optim_name == "AdamW4bit": |
| 121 | + optim1 = lpmm.optim.AdamW(model1.parameters()) |
| 122 | + else: |
| 123 | + raise ValueError(f"Unsupported {optim_name} optimizer for lpmm") |
| 124 | + optim2 = getattr(low_bit_optim, optim_name)(model2.parameters()) |
| 125 | + |
| 126 | + for _ in range(2): |
| 127 | + x = torch.randn(4, 32, device=device) |
| 128 | + |
| 129 | + loss1 = model1(x).sum() |
| 130 | + loss1.backward() |
| 131 | + optim1.step() |
| 132 | + optim1.zero_grad() |
| 133 | + |
| 134 | + loss2 = model2(x).sum() |
| 135 | + loss2.backward() |
| 136 | + optim2.step() |
| 137 | + optim2.zero_grad() |
| 138 | + |
| 139 | + for p1, p2 in zip(model1.parameters(), model2.parameters()): |
| 140 | + torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5) |
| 141 | + |
| 142 | + |
| 143 | +instantiate_parametrized_tests(TestQuantize) |
| 144 | +instantiate_parametrized_tests(TestOptim) |
| 145 | + |
| 146 | + |
| 147 | +if __name__ == "__main__": |
| 148 | + run_tests() |
0 commit comments