
Commit bc2f8b7

Generalize Model Size Code (#364)
* Generalize Model Size Code

  Summary: Previously this worked only on model-swap quantized classes, and the version in generate.py was specific to a few cases. This new version is significantly more general and is now consolidated in a single place.

  Test Plan: python test_integration.py -k "test_get_model_size"

* fix test failures
1 parent 2b56245 commit bc2f8b7
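For orientation before the diffs: the consolidated helper is exercised end to end by the new autoquant test below. A minimal standalone sketch along the same lines (CUDA and bfloat16 assumed; this mirrors the test rather than documenting a guaranteed API):

    import torch
    import torchao
    from torchao.quantization.autoquant import AQWeightOnlyQuantizedLinearWeight2

    model = torch.nn.Sequential(
        torch.nn.ReLU(),
        torch.nn.Linear(128, 128),
        torch.nn.ReLU(),
    ).to("cuda").to(torch.bfloat16)
    example_input = torch.randn(16, 128, device="cuda", dtype=torch.bfloat16)

    size_before = torchao.utils.get_model_size_in_bytes(model)
    # Restrict autoquant to a weight-only int8 class so a quantized kernel
    # is actually picked, as the new test does.
    mod = torchao.autoquant(
        torch.compile(model),
        qtensor_class_list=(AQWeightOnlyQuantizedLinearWeight2,),
    )
    mod(example_input)  # autoquant finalizes its choices on the first call
    size_after = torchao.utils.get_model_size_in_bytes(mod)
    assert size_after < size_before  # quantized weights take fewer bytes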

File tree

3 files changed: +82 −24 lines changed

test/integration/test_integration.py
torchao/_models/llama/generate.py
torchao/utils.py

test/integration/test_integration.py

Lines changed: 54 additions & 0 deletions
@@ -1373,6 +1373,60 @@ def forward(self, x):
         after_export = model(x)
         self.assertTrue(torch.equal(after_export, ref))
 
+class TestUtils(unittest.TestCase):
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
+    def test_get_model_size_autoquant(self, device, dtype):
+        if device != "cuda" and dtype != torch.bfloat16:
+            self.skipTest(f"autoquant currently does not support {device}")
+        if device != "cuda" or not torch.cuda.is_available():
+            self.skipTest(f"autoquant currently does not support {device}")
+        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
+            if dtype == torch.bfloat16:
+                self.skipTest("bfloat16 requires sm80+")
+        m, k, n = 16, 128, 128
+        model = torch.nn.Sequential(
+            torch.nn.ReLU(),
+            torch.nn.Linear(k, n),
+            torch.nn.ReLU(),
+        ).to(device).to(dtype)
+        example_input = torch.randn(m, k, device=device, dtype=dtype)
+        size = torchao.utils.get_model_size_in_bytes(model)
+
+        from torchao.quantization.autoquant import (
+            AQWeightOnlyQuantizedLinearWeight2,
+        )
+        qtensor_class_list = (
+            AQWeightOnlyQuantizedLinearWeight2,
+        )
+
+        mod = torchao.autoquant(torch.compile(model), qtensor_class_list=qtensor_class_list)
+        mod(example_input)
+        size2 = torchao.utils.get_model_size_in_bytes(mod)
+        self.assertTrue(size2 < size)
+
+    @parameterized.expand(
+        list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
+    )
+    def test_get_model_size_aqt(self, api, test_device, test_dtype):
+        if test_dtype != torch.bfloat16:
+            self.skipTest(f"{api} in {test_dtype} is not supported yet")
+        if test_device != "cuda" or not torch.cuda.is_available():
+            self.skipTest(f"{api} currently does not support {test_device}")
+        k, n = 1024, 1024
+        model = torch.nn.Sequential(
+            torch.nn.ReLU(),
+            torch.nn.Linear(k, n),
+            torch.nn.ReLU(),
+        ).to(test_device).to(test_dtype)
+        size = torchao.utils.get_model_size_in_bytes(model)
+        api(model)
+        size2 = torchao.utils.get_model_size_in_bytes(model)
+        self.assertTrue(size2 < size)
+
 
 if __name__ == "__main__":
     unittest.main()

torchao/_models/llama/generate.py

Lines changed: 2 additions & 16 deletions
@@ -13,6 +13,7 @@
 import torchao
 import torch._dynamo.config
 import torch._inductor.config
+from torchao.utils import get_model_size_in_bytes
 
 def device_sync(device):
     if "cuda" in device:
@@ -143,21 +144,6 @@ def _load_model(checkpoint_path, device, precision):
 
     return model.eval()
 
-def _get_model_size(model):
-    model_size = 0
-    for name, child in model.named_children():
-        if not isinstance(child, torch.nn.Embedding):
-            for p in itertools.chain(child.parameters(), child.buffers()):
-                # handling for tensor subclasses
-                if isinstance(p, torchao.dtypes.aqt.AffineQuantizedTensor):
-                    layout_tensor = p.layout_tensor
-                    for attr_name in layout_tensor.__tensor_flatten__()[0]:
-                        sub_tensor = getattr(layout_tensor, attr_name)
-                        model_size += sub_tensor.numel() * sub_tensor.element_size()
-                else:
-                    model_size += p.numel() * p.element_size()
-    return model_size
 
 B_INST, E_INST = "[INST]", "[/INST]"
 
 def main(
@@ -226,7 +212,7 @@ def main(
         interactive=False
     )
 
-    model_size = _get_model_size(model) / 1e9
+    model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9
 
     if compile:
         global decode_one_token, prefill
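The ignore_embeddings=True call above preserves the old _get_model_size behavior: the Llama embedding table is large relative to the layers that get quantized, so counting it would dilute the reported size. A toy illustration of the flag (the model here is hypothetical, not from this repo):

    import torch
    from torchao.utils import get_model_size_in_bytes

    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(32000, 256)  # ~32.8 MB in fp32
            self.proj = torch.nn.Linear(256, 256)      # ~0.26 MB in fp32

        def forward(self, idx):
            return self.proj(self.emb(idx))

    m = Toy()
    full = get_model_size_in_bytes(m)
    no_emb = get_model_size_in_bytes(m, ignore_embeddings=True)
    print(full - no_emb)  # 32768000 bytes: exactly the embedding table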

torchao/utils.py

Lines changed: 26 additions & 8 deletions
@@ -5,6 +5,7 @@
 from math import gcd
 from packaging import version
 import torch.nn.utils.parametrize as parametrize
+import itertools
 
 __all__ = [
     "benchmark_model",
@@ -82,14 +83,31 @@ def find_multiple(n: int, *args: Tuple[int]) -> int:
         return n
     return n + k - (n % k)
 
-# https://discuss.pytorch.org/t/finding-model-size/130275
-def get_model_size_in_bytes(model):
-    s = 0
-    for p in model.parameters():
-        s += p.nelement() * p.element_size()
-    for b in model.buffers():
-        s += b.nelement() * b.element_size()
-    return s
+def get_model_size_in_bytes(model, ignore_embeddings=False):
+    """
+    Returns the model size in bytes. The option to ignore embeddings
+    is useful for models with disproportionately large embeddings compared
+    to other model parameters that get quantized/sparsified.
+    """
+    def flat_size(tensor):
+        if hasattr(tensor, "__tensor_flatten__"):
+            size = 0
+            # 0th element is a list of attributes that hold tensors
+            for attr_name in tensor.__tensor_flatten__()[0]:
+                sub_tensor = getattr(tensor, attr_name)
+                size += flat_size(sub_tensor)
+            return size
+        else:
+            return tensor.numel() * tensor.element_size()
+
+    model_size = 0
+    for name, child in model.named_children():
+        if not (isinstance(child, torch.nn.Embedding) and ignore_embeddings):
+            for p in itertools.chain(child.parameters(recurse=False), child.buffers(recurse=False)):
+                model_size += flat_size(p)
+            model_size += get_model_size_in_bytes(child, ignore_embeddings)
+    return model_size
 
 class UnwrapTensorSubclass(torch.nn.Module):
     def forward(self, *tensors):
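The recursion in flat_size is what makes the new version general: it leans on PyTorch's tensor-subclass flattening protocol, where __tensor_flatten__()[0] names the attributes that hold inner tensors, so any subclass (such as AffineQuantizedTensor) is sized without hard-coding its layout. A hypothetical stand-in to show the mechanics (FakeQuantized is illustrative only, not a torchao class):

    import torch

    class FakeQuantized:
        # Hypothetical object mimicking the flatten protocol; real torchao
        # subclasses like AffineQuantizedTensor implement the same hook.
        def __init__(self):
            self.int_data = torch.zeros(128, 128, dtype=torch.int8)  # 16384 bytes
            self.scale = torch.zeros(128, dtype=torch.float32)       # 512 bytes

        def __tensor_flatten__(self):
            # 0th element: attribute names holding tensors; 1st: extra metadata
            return ["int_data", "scale"], None

    t = FakeQuantized()
    names = t.__tensor_flatten__()[0]
    total = sum(getattr(t, n).numel() * getattr(t, n).element_size() for n in names)
    print(total)  # 16896, what flat_size would report for this object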
