From 9cac085d3abd951226af90f79d2b02cb434a6552 Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Mon, 10 Mar 2025 09:59:53 -0700
Subject: [PATCH 1/2] Remove prototype profiler

---
 test/prototype/test_device_spec.py            |  70 --
 test/prototype/test_performance_counter.py    | 528 ----------------
 test/prototype/utils.py                       | 103 ---
 torchao/_models/llama/perf_profile.py         | 442 -------------
 torchao/prototype/profiler/__init__.py        |  21 -
 torchao/prototype/profiler/device_spec.py     | 421 ------------
 .../prototype/profiler/performance_counter.py | 597 ------------------
 torchao/prototype/profiler/utils.py           |  44 --
 8 files changed, 2226 deletions(-)
 delete mode 100644 test/prototype/test_device_spec.py
 delete mode 100644 test/prototype/test_performance_counter.py
 delete mode 100644 test/prototype/utils.py
 delete mode 100644 torchao/_models/llama/perf_profile.py
 delete mode 100644 torchao/prototype/profiler/__init__.py
 delete mode 100644 torchao/prototype/profiler/device_spec.py
 delete mode 100644 torchao/prototype/profiler/performance_counter.py
 delete mode 100644 torchao/prototype/profiler/utils.py

diff --git a/test/prototype/test_device_spec.py b/test/prototype/test_device_spec.py
deleted file mode 100644
index dd159f5336..0000000000
--- a/test/prototype/test_device_spec.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import pytest
-
-cuda_driver = pytest.importorskip(
-    "triton.runtime.driver", reason="requires triton cuda driver module"
-)
-import itertools
-
-import torch
-from utils import patch_device
-
-from torchao.prototype.profiler.device_spec import (
-    _AVAILABLE_GPU_SPECS,
-    CUDADeviceSpec,
-    get_chip_name,
-)
-
-# -------------------- Device Spec Tests ------------------- #
-DEVICE_NAMES = ["h100 sxm", "a100", "nvidia geforce rtx 4090"]
-DTYPES = [torch.float32, torch.bfloat16, torch.float16]
-USE_TENSORCORES = [True, False]
-DEVICE_CONFIGS = itertools.product(DEVICE_NAMES, DTYPES, USE_TENSORCORES)
-
-
-@pytest.mark.parametrize(
-    "device_name, dtype, use_tensorcores", DEVICE_CONFIGS, ids=lambda x: str(x)
-)
-def test_device_spec(device_name, dtype, use_tensorcores):
-    with patch_device(device_name):
-        device_spec = CUDADeviceSpec(dtype=dtype, use_tensorcores=use_tensorcores)
-        if dtype == torch.float32 and use_tensorcores:
-            dtype = "tfloat32"
-        chip_name = get_chip_name(device_name)
-        expected_flops = _AVAILABLE_GPU_SPECS[chip_name][dtype]
-        assert device_spec.flops_per_s == expected_flops
-        assert device_spec.flops_by_dtype[dtype] == expected_flops
-        assert (
-            device_spec.roofline_balancepoint == expected_flops / device_spec.bandwidth
-        )
-
-        with pytest.raises(AssertionError):
-            device_spec.flops_per_s = None
-            print(device_spec.roofline_balancepoint)
-        # Prevent setting attributes not in named fields to guard against user error
-        with pytest.raises(AttributeError):
-            device_spec.FLOPs = None
-
-
-def test_empty_device_spec():
-    device_name = "fake device"
-    with patch_device(device_name):
-        with pytest.raises(AssertionError):
-            _ = CUDADeviceSpec()
-
-        # Ok to instantiate as long as fields are filled
-        _ = CUDADeviceSpec(
-            name=device_name,
-            flops_per_s=1.0,
-            bandwidth=1.0,
-            dtype=torch.float32,
-            use_tensorcores=True,
-        )
-    device_name = DEVICE_NAMES[0]
-
-    with patch_device(device_name):
-        # All critical fields will be auto-filled except for dtype (and vram, but vram is not used for downstream calcs atm)
-        _ = CUDADeviceSpec(dtype=torch.float32)
-
-        # No dtype specified
-        with pytest.raises(AssertionError):
-            _ = CUDADeviceSpec()
diff --git a/test/prototype/test_performance_counter.py
b/test/prototype/test_performance_counter.py deleted file mode 100644 index 1659cff53b..0000000000 --- a/test/prototype/test_performance_counter.py +++ /dev/null @@ -1,528 +0,0 @@ -import pytest - -# Skip if transformers is not installed -transformers = pytest.importorskip("transformers") -LlamaConfig = transformers.models.llama.modeling_llama.LlamaConfig -LlamaForCausalLM = transformers.models.llama.modeling_llama.LlamaForCausalLM - -import json -import tempfile -import time -import unittest -from dataclasses import asdict -from pathlib import Path -from typing import Union - -import torch -from parameterized import parameterized_class -from utils import ( - PerfCounterManagerTestConfig, - PerfCounterResult, - PerfCounterTestConfig, - PerfStatsTestConfig, - attn_io_check, - ffn_io_check, - get_leaf_nodes, - get_test_name, - patch_device, - qkv_proj_io_check, -) - -from torchao.prototype.profiler.device_spec import CUDADeviceSpec, DeviceSpec -from torchao.prototype.profiler.performance_counter import ( - CUDAPerformanceTimer, - PerformanceCounterMode, - PerformanceStats, - PerformanceTimer, - TransformerPerformanceCounter, -) -from torchao.utils import TORCH_VERSION_AFTER_2_5 - -# ------------------- PerformanceCounter Tests ------------------- # - -PERFCOUNTER_TEST_CONFIGS = [ - PerfCounterTestConfig( - name="3.5B", - batch_size=1, - seqlen=128, - dtype=torch.float16, - num_hidden_layers=32 // 2, - hidden_size=4096 // 2, - intermediate_size=11008 // 2, - num_attention_heads=32 // 2, - vocab_size=32000 // 2, - ), - PerfCounterTestConfig( - name="1.25B", - batch_size=1, - seqlen=128, - dtype=torch.float16, - num_hidden_layers=32 // 4, - hidden_size=4096 // 4, - intermediate_size=11008 // 4, - num_attention_heads=32 // 4, - vocab_size=32000 // 4, - ), - PerfCounterTestConfig( - name="tiny", - batch_size=1, - seqlen=128, - dtype=torch.float16, - num_hidden_layers=1, - hidden_size=4096 // 4, - intermediate_size=11008 // 4, - num_attention_heads=32 // 4, - vocab_size=32000 // 4, - ), -] - - -@unittest.skipIf( - not TORCH_VERSION_AFTER_2_5, "PerformanceCounter requires torch >= 2.5+." 
-) -@unittest.skipIf(not torch.cuda.is_available(), "PerformanceCounter requires CUDA") -@parameterized_class( - [asdict(cfg) for cfg in PERFCOUNTER_TEST_CONFIGS], class_name_func=get_test_name -) -class PerformanceCounterTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - model_cfg = LlamaConfig( - num_hidden_layers=cls.num_hidden_layers, - hidden_size=cls.hidden_size, - intermediate_size=cls.intermediate_size, - num_attention_heads=cls.num_attention_heads, - vocab_size=cls.vocab_size, - ) - - # Note we set some options manually since the model doesn't seem to be initialized correctly - # when these options are set in LlamaConfig - model_cfg._attn_implementation = "sdpa" - cls.model = model = LlamaForCausalLM(model_cfg).to(cls.dtype).to("cuda") - cls.model_config = model.config - cls.element_size = cls.dtype.itemsize - - input_ids = torch.randint( - 0, model.config.vocab_size, (cls.batch_size, cls.seqlen), device="cuda" - ) - with torch.no_grad(): - with torch.nn.attention.sdpa_kernel( - torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION - ): - with PerformanceCounterMode() as perf_counter: - _ = model(input_ids) - cls.perf_counter = perf_counter - cls.summary_flops = perf_counter.get_summary_flop_counts() - cls.summary_io = perf_counter.get_summary_io_counts() - cls.flops_by_op = perf_counter.get_flop_counts() - cls.io_by_op = perf_counter.get_io_counts() - - def test_qkv_proj(self): - batch_size, seqlen = self.batch_size, self.seqlen - element_size = self.element_size - - assert len(self.summary_flops) == len(self.summary_io) - assert self.summary_flops.keys() == self.summary_io.keys() - - # Attn Projections - for k in ["q_proj", "k_proj", "v_proj"]: - # Flops check - proj_keys = get_leaf_nodes(self.summary_flops.keys(), k) - assert len(proj_keys) == self.model.config.num_hidden_layers - expected_flops = ( - 2 - * batch_size - * seqlen - * self.model_config.hidden_size - * self.model_config.hidden_size - ) - assert expected_flops == self.summary_flops[proj_keys[0]] - - # io check - expected_size = qkv_proj_io_check( - self.model_config, batch_size, seqlen, element_size - ) - assert expected_size == self.summary_io[proj_keys[0]] - - def test_attn(self): - batch_size, seqlen = self.batch_size, self.seqlen - element_size = self.element_size - model_config = self.model.config - - attention_keys = get_leaf_nodes(self.summary_flops.keys(), "self_attn") - for k in attention_keys: - flops = self.flops_by_op[k] - io_movement = self.io_by_op[k] - for op, count in flops.items(): - if "attention" in op.__name__: - expected_flops = ( - 2 * 2 * batch_size * seqlen * seqlen * model_config.hidden_size - ) - assert expected_flops == count - for op, count in io_movement.items(): - if "attention" in op.__name__: - # Check approx equal due to other small artifacts returned by sdpa.mem_efficient_attention - # See #https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/attention.cu#L867 - # Check within 100 bytes - expected_size = attn_io_check( - model_config, batch_size, seqlen, element_size - ) - assert abs(expected_size - count) < 100 - - def test_ffn(self): - batch_size, seqlen = self.batch_size, self.seqlen - element_size = self.element_size - - for k in ["up_proj", "gate_proj", "down_proj"]: - proj_keys = get_leaf_nodes(self.summary_flops.keys(), k) - assert len(proj_keys) == self.model.config.num_hidden_layers - expected_flops = ( - 2 - * batch_size - * seqlen - * self.model_config.hidden_size - * self.model_config.intermediate_size - ) - assert expected_flops 
== self.summary_flops[proj_keys[0]] - - # io check - expected_size = ffn_io_check( - self.model_config, batch_size, seqlen, element_size, k - ) - assert expected_size == self.summary_io[proj_keys[0]] - - -# ------------------- PerformanceStats Tests ------------------- # - -PERFSTATS_TEST_CONFIGS = [ - PerfStatsTestConfig( - label="with_device", - num_tokens=128, - latency=0.1, - total_flops=123e9, - total_io=123e6, - flops_summary={"a": 234e12, "b": 345e9}, - io_summary={"a": 1, "b": 2}, - flop_counts={"a": 234e12, "b": 345e9}, - io_counts={"a": 1, "b": 2}, - device_bandwidth=1e9, - device_flops_per_s=23e9, - ), - PerfStatsTestConfig( - label="no_device", - num_tokens=128, - latency=0.1, - total_flops=123e9, - total_io=123e6, - flops_summary={"a": 234e12, "b": 345e9}, - io_summary={"a": 1, "b": 2}, - flop_counts={"a": 234e12, "b": 345e9}, - io_counts={"a": 1, "b": 2}, - device_bandwidth=None, - device_flops_per_s=None, - ), -] - - -@pytest.mark.parametrize("cfg", PERFSTATS_TEST_CONFIGS, ids=lambda cfg: cfg.label) -def test_performance_stats(cfg: PerfStatsTestConfig): - stats = PerformanceStats(**asdict(cfg)) - num_tokens = cfg.num_tokens - latency = cfg.latency - total_flops = cfg.total_flops - total_io = cfg.total_io - device_bandwidth = cfg.device_bandwidth - device_flops_per_s = cfg.device_flops_per_s - - # Test derived metrics - assert stats.token_throughput == num_tokens / latency - assert stats.achieved_bandwidth == total_io / latency - assert stats.achieved_flops_per_s == total_flops / latency - if device_bandwidth is not None: - assert ( - stats.bandwidth_utilization == stats.achieved_bandwidth / device_bandwidth - ) - assert stats.theoretical_io_latency == total_io / device_bandwidth - else: - assert stats.bandwidth_utilization is None - assert stats.theoretical_io_latency is None - if device_flops_per_s is not None: - assert ( - stats.flops_utilization == stats.achieved_flops_per_s / device_flops_per_s - ) - assert stats.theoretical_compute_latency == total_flops / device_flops_per_s - else: - assert stats.flops_utilization is None - assert stats.theoretical_compute_latency is None - - # Test str - stats should be formatted to closest power of 10 ** 3 with 2 decimal places of precision - stats_str = str(stats) - - # Base Stats - expected_io_str = ".12 GB" - expected_flops_str = ".12 TFLOPs" - assert expected_io_str in stats_str - assert expected_flops_str in stats_str - - # Derived Stats - expected_io_throughput_str = "1.23 GB/s" - expected_flops_throughput_str = "1.23 TFLOPs/s" - assert expected_io_throughput_str in stats_str - assert expected_flops_throughput_str in stats_str - - # Utilization Stats - if device_bandwidth is not None: - expected_bandwidth_utilization_str = ( - f"{stats.achieved_bandwidth / device_bandwidth:.4f}" - ) - expected_io_latency_str = f"{stats.theoretical_io_latency:.2f} s" - assert expected_bandwidth_utilization_str in stats_str - assert expected_io_latency_str in stats_str - - if device_flops_per_s is not None: - expected_flops_utilization_str = ( - f"{stats.achieved_flops_per_s / device_flops_per_s:.4f}" - ) - expected_compute_latency_str = f"{stats.theoretical_compute_latency:.2f} s" - assert expected_flops_utilization_str in stats_str - assert expected_compute_latency_str in stats_str - - -# ------------------- TransformerPerformanceCounter Tests ------------------- # - -PERFCOUNTERMANAGER_TEST_CONFIGS = [ - PerfCounterManagerTestConfig( - "no_device", (1, 1024, 4096, 4096), PerformanceTimer, torch.bfloat16, (None, 0) - ), - 
PerfCounterManagerTestConfig( - "a100", - (1, 1024, 4096, 4096), - CUDAPerformanceTimer, - torch.bfloat16, - ("A100", 2e12), - ), -] - - -@unittest.skipIf( - not TORCH_VERSION_AFTER_2_5, "TransformerPerformanceCounter requires torch >= 2.5+." -) -@unittest.skipIf( - not torch.cuda.is_available(), "TransformerPerformanceCounter requires CUDA" -) -@parameterized_class( - [asdict(cfg) for cfg in PERFCOUNTERMANAGER_TEST_CONFIGS], - class_name_func=get_test_name, -) -class TestTransformerPerformanceCounter(unittest.TestCase): - @classmethod - def setUpClass(cls): - shape, timer_cls, dtype = cls.shape, cls.timer_cls, cls.dtype - batch_size, query_len, in_features, out_features = shape - num_tokens = batch_size * query_len - element_size = dtype.itemsize - a = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") - b = torch.randn(in_features, out_features, dtype=dtype, device="cuda") - # Set up device spec - device_name, bandwidth = cls.device_spec - if device_name is not None: - with patch_device(device_name): - device_spec = CUDADeviceSpec(dtype=torch.bfloat16, bandwidth=bandwidth) - - else: - device_spec = None - - # Stateful class level objects, which will be used in individual tests - cls.cm = cm = TransformerPerformanceCounter( - timer_cls=timer_cls, device_spec=device_spec - ) - cls.FLOAT_TOL = 1e-5 - cls.expected = expected = {} - - # Start count for a - start = time.perf_counter() - with cm.count("a", num_tokens=num_tokens): - _ = torch.matmul(a, b) - end = time.perf_counter() - - latency = end - start - expected_flops = 2 * num_tokens * in_features * out_features - expected_io = ( - num_tokens * in_features - + in_features * out_features - + num_tokens * out_features - ) * element_size - - expected["a"] = PerfCounterResult( - name="a", - latency=latency, - flops=expected_flops, - io=expected_io, - total_flops=expected_flops, - total_io=expected_io, - ) - - # Start count for b - start = time.perf_counter() - with cm.count("b", num_tokens=num_tokens): - _ = torch.matmul(a, b) - end = time.perf_counter() - latency = end - start - - expected["b"] = PerfCounterResult( - name="b", - latency=latency, - flops=expected_flops, - io=expected_io, - total_flops=cm.total_flops, - total_io=cm.total_io, - ) - - def test_perf_stats_a(self): - cm: TransformerPerformanceCounter = self.cm - expected = self.expected["a"] - - counts = cm.get_counts() - assert "a" in counts - - # Check captured performance stats - psa: PerformanceStats = counts["a"] - # Raw metrics - # Latency won't be exact since timing external to the profiler - assert abs(psa.latency - expected.latency) < 1e-1 # +/- 100ms - assert psa.total_flops == expected.flops - assert psa.total_io == expected.io - - # Derived metrics - assert psa.token_throughput == psa.num_tokens / psa.latency - assert psa.achieved_flops_per_s == psa.total_flops / psa.latency - assert psa.achieved_bandwidth == psa.total_io / psa.latency - - def test_perf_stats_b(self): - cm: TransformerPerformanceCounter = self.cm - assert "a" in cm.counts - assert "b" in cm.counts - psa = cm.counts["a"] - psb = cm.counts["b"] - expected = self.expected["b"] - assert abs(psb.latency - expected.latency) < 1e-1 # +/- 100ms - assert psb.total_flops == expected.flops - assert psb.total_io == expected.io - - # check that **total** flops and io after matmul `b` has run accounts for both matmuls - # also check that these global properties are updated correctly in the manager object - assert ( - expected.total_flops == psa.total_flops + psb.total_flops == cm.total_flops - ) - 
assert expected.total_io == psa.total_io + psb.total_io == cm.total_io - assert cm.total_time == psa.latency + psb.latency - - def test_stats_summary(self): - cm: TransformerPerformanceCounter = self.cm - FLOAT_TOL = self.FLOAT_TOL - psa = cm.counts["a"] - psb = cm.counts["b"] - summary: PerformanceStats = cm.stats_summary - - # Raw stats - assert summary.num_tokens == psa.num_tokens + psb.num_tokens - assert summary.total_io == psa.total_io + psb.total_io - assert summary.total_flops == psa.total_flops + psb.total_flops - assert summary.latency == psa.latency + psb.latency - - # Derived stats - expected_token_throughput = (psa.num_tokens + psb.num_tokens) / ( - psa.latency + psb.latency - ) - expected_io_throughput = (psa.total_io + psb.total_io) / ( - psa.latency + psb.latency - ) - expected_flops_throughput = (psa.total_flops + psb.total_flops) / ( - psa.latency + psb.latency - ) - assert abs(summary.token_throughput - expected_token_throughput) < FLOAT_TOL - assert abs(summary.achieved_bandwidth - expected_io_throughput) < FLOAT_TOL - assert abs(summary.achieved_flops_per_s - expected_flops_throughput) < FLOAT_TOL - - device_spec = cm.device_spec - if device_spec is not None: - expected_bandwidth_utilization = ( - expected_io_throughput / device_spec.bandwidth - ) - expected_flops_utilization = ( - expected_flops_throughput / device_spec.flops_per_s - ) - assert ( - abs(summary.bandwidth_utilization - expected_bandwidth_utilization) - < FLOAT_TOL - ) - assert ( - abs(summary.flops_utilization - expected_flops_utilization) < FLOAT_TOL - ) - else: - assert summary.bandwidth_utilization is None - assert summary.flops_utilization is None - - def test_json(self): - cm: TransformerPerformanceCounter = self.cm - psa: PerformanceStats = cm.counts["a"] - psb: PerformanceStats = cm.counts["b"] - device_spec: Union[DeviceSpec, None] = cm.device_spec - - with tempfile.TemporaryDirectory() as tmp_dir: - json_path = Path(tmp_dir) / "test.json" - cm.to_json(json_path) - - with open(json_path, "r") as f: - perf_dict = json.load(f) - - assert "a" in perf_dict - assert "b" in perf_dict - - # Test basic stats are recorded properly - assert perf_dict["a"]["num_tokens"] == psa.num_tokens - assert perf_dict["a"]["total_io"] == psa.total_io - assert perf_dict["a"]["total_flops"] == psa.total_flops - assert perf_dict["a"]["latency"] == psa.latency - - assert perf_dict["b"]["num_tokens"] == psb.num_tokens - assert perf_dict["b"]["total_io"] == psb.total_io - assert perf_dict["b"]["total_flops"] == psb.total_flops - assert perf_dict["b"]["latency"] == psb.latency - - # Test derived properties are present - perf_dict["a"]["achieved_flops_per_s"] == psa.achieved_flops_per_s - perf_dict["a"]["achieved_bandwidth"] == psa.achieved_bandwidth - perf_dict["b"]["achieved_flops_per_s"] == psb.achieved_flops_per_s - perf_dict["b"]["achieved_bandwidth"] == psb.achieved_bandwidth - - if device_spec is not None: - assert perf_dict["a"]["device_flops_per_s"] == device_spec.flops_per_s - assert perf_dict["a"]["device_bandwidth"] == device_spec.bandwidth - assert ( - perf_dict["a"]["theoretical_io_latency"] - == psa.theoretical_io_latency - ) - assert ( - perf_dict["a"]["theoretical_compute_latency"] - == psa.theoretical_compute_latency - ) - assert ( - perf_dict["a"]["bandwidth_utilization"] == psa.bandwidth_utilization - ) - assert perf_dict["a"]["flops_utilization"] == psa.flops_utilization - - assert perf_dict["b"]["device_flops_per_s"] == device_spec.flops_per_s - assert perf_dict["b"]["device_bandwidth"] == 
device_spec.bandwidth - assert ( - perf_dict["b"]["theoretical_io_latency"] - == psb.theoretical_io_latency - ) - assert ( - perf_dict["b"]["theoretical_compute_latency"] - == psb.theoretical_compute_latency - ) - assert ( - perf_dict["b"]["bandwidth_utilization"] == psb.bandwidth_utilization - ) - assert perf_dict["b"]["flops_utilization"] == psb.flops_utilization diff --git a/test/prototype/utils.py b/test/prototype/utils.py deleted file mode 100644 index 8c402b8114..0000000000 --- a/test/prototype/utils.py +++ /dev/null @@ -1,103 +0,0 @@ -from contextlib import contextmanager -from dataclasses import dataclass -from typing import Optional -from unittest.mock import patch - -import torch - -from torchao.prototype.profiler import PerformanceTimer - - -@contextmanager -def patch_device(device_name): - with patch("torch.cuda.get_device_name", return_value=device_name): - yield - - -@dataclass(frozen=True) -class PerfCounterTestConfig: - name: str - batch_size: int - seqlen: int - dtype: torch.dtype - num_hidden_layers: int - hidden_size: int - intermediate_size: int - num_attention_heads: int - vocab_size: int - - -def get_leaf_nodes(count_keys, module_name): - return [k for k in count_keys if k.endswith(module_name)] - - -def qkv_proj_io_check(model_config, batch_size, seqlen, element_size): - input_size = batch_size * seqlen * model_config.hidden_size * element_size - weight_size = model_config.hidden_size * model_config.hidden_size * element_size - output_size = batch_size * seqlen * model_config.hidden_size * element_size - return input_size + weight_size + output_size - - -def attn_io_check(model_config, batch_size, seqlen, element_size): - # queries, keys, values -> factor of 3 - input_size = (batch_size * seqlen * model_config.hidden_size * 3) * element_size - output_size = (batch_size * seqlen * model_config.hidden_size) * element_size - return input_size + output_size - - -def ffn_io_check(model_config, batch_size, seqlen, element_size, module_name): - assert module_name in ["up_proj", "gate_proj", "down_proj"] - - if module_name == "down_proj": - input_size = batch_size * seqlen * model_config.intermediate_size * element_size - else: - input_size = batch_size * seqlen * model_config.hidden_size * element_size - weight_size = ( - model_config.hidden_size * model_config.intermediate_size * element_size - ) - if module_name == "down_proj": - output_size = batch_size * seqlen * model_config.hidden_size * element_size - else: - output_size = ( - batch_size * seqlen * model_config.intermediate_size * element_size - ) - - return input_size + weight_size + output_size - - -@dataclass(frozen=True) -class PerfStatsTestConfig: - label: str - num_tokens: int - latency: float - total_flops: float - total_io: float - flops_summary: dict - io_summary: dict - flop_counts: dict - io_counts: dict - device_bandwidth: Optional[float] = None - device_flops_per_s: Optional[float] = None - - -def get_test_name(cls, num, params_dict): - return f"{cls.__name__}_{num}_{params_dict['name']}" - - -@dataclass(frozen=True) -class PerfCounterResult: - name: str - latency: float - flops: float - io: float - total_flops: float - total_io: float - - -@dataclass -class PerfCounterManagerTestConfig: - name: str - shape: tuple[int] - timer_cls: PerformanceTimer - dtype: torch.dtype - device_spec: tuple[Optional[str], int] diff --git a/torchao/_models/llama/perf_profile.py b/torchao/_models/llama/perf_profile.py deleted file mode 100644 index f613982221..0000000000 --- a/torchao/_models/llama/perf_profile.py +++ 
/dev/null @@ -1,442 +0,0 @@ -""" - -## Performance Profiling Example - -An minimal version of `gpt-fast generate.py` that demonstrates usage of `torchao.prototype.profiler.TransformerPerformanceCounter`. -- Outputs from gpt-fast are prefixed with GPT-Fast -- Outputs from `torchao.prototype.profiler.TransformerPerformanceCounter` are prefixed with `TransformerPerfCounter`. - -## Usage -```python -python perf_profile.py --prompt "Hello my name is" --checkpoint_path path/to/model.pth --num_samples 1 --max_new_tokens 2 --save_path performance_stats.json -``` -where `checkpoint_path` is the checkpoint path of the converted model weights per `gpt-fast` and `save_path` specifies where to save performance stats. - - -Running the above command for `llama2-7b` should print the following, with accumulated stats saved to `performance_stats.json` - -``` -Loading model ... -Time to load model: 20.14 seconds - -============================== - -Using DeviceSpec(device_type=cuda, name=NVIDIA GeForce RTX 3090, dtype=torch.bfloat16, bandwidth=936.1GB/s, flops=35.6TFLOPs, vram=25.4GB) -Model Config: ModelArgs(block_size=2048, vocab_size=32000, n_layer=32, n_head=32, dim=4096, intermediate_size=11008, n_local_heads=32, head_dim=128, rope_base=10000, norm_eps=1e-05) -Active params, Total Params: 6607343616, 6738415616 - -============================== - -TransformerPerfCounter Metrics -PREFILL_SEQLEN-6: - Latency = 1.26 s - Tokens - Total: 6 tokens - Throughput: 5 tokens/s - IO - Total: 13.25 GB - Throughput: 10.54 GB/s - Theoretical Latency: 14.15 ms - FLOPs - Total: 79.31 GFLOPs - Throughput: 63.06 GFLOPs/s - Theoretical Latency: 2.23 ms - Utilization - Bandwidth: 0.0113 % - FLOPs: 0.0018 % - -============================== - -TransformerPerfCounter Metrics -DECODE_CTX-6_NUM_TOKS-1: - Latency = 0.16 s - Tokens - Total: 1 tokens - Throughput: 6 tokens/s - IO - Total: 13.22 GB - Throughput: 83.27 GB/s - Theoretical Latency: 14.13 ms - FLOPs - Total: 13.22 GFLOPs - Throughput: 83.24 GFLOPs/s - Theoretical Latency: 0.37 ms - Utilization - Bandwidth: 0.0890 % - FLOPs: 0.0023 % - -============================== - -Generated text for sample 0: Hello, my name is [Name - -GPTFast Sample Metrics - Time for inference 1: 6 prompt tokens 2 tokens generated, 1.57 sec total, 1.28 tokens/sec - Bandwidth achieved: 17.22 GB/s - -============================== - -GPTFast Aggregate Stats - Average tokens/sec: 1.28 - Memory used: 13.51 GB - -============================== - -TransformerPerfCounter -Performance Summary: - Latency = 1.42 s - Tokens - Total: 7 tokens - Throughput: 5 tokens/s - IO - Total: 26.47 GB - Throughput: 18.69 GB/s - Theoretical Latency: 28.28 ms - FLOPs - Total: 92.53 GFLOPs - Throughput: 65.33 GFLOPs/s - Theoretical Latency: 2.60 ms - Utilization - Bandwidth: 0.0200 % - FLOPs: 0.0018 % - -Saving performance results to performance_stats.json -``` - -**Notes** -- The discrepancy between `gpt-fast` token throughput and that of `TransformerPerformanceCounter` is due to the fact that gpt-fast` only counts generated tokens (no prefill) --- so even though the `prefill` phase technically generates `len(prompt) + 1` tokens, it counts the number of tokens generated during this phase as `1`, -whereas `TransformerPerformanceCounter` includes all `prefill` tokens in the total token count. 
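To make the accounting concrete, here is a small worked example using the sample run above (6 prompt tokens, `--max_new_tokens 2`). This is an illustrative sketch only; the variable names are not part of the script:

```python
prompt_tokens, max_new_tokens = 6, 2

# GPT-Fast reports only generated tokens: the prefill step is credited with a
# single token, plus one token per subsequent decode step.
gptfast_tokens = max_new_tokens                             # 1 (prefill) + 1 (decode) = 2

# TransformerPerformanceCounter reports every token it processes: the full
# prompt during prefill, plus one token per decode step.
perf_counter_tokens = prompt_tokens + (max_new_tokens - 1)  # 6 (prefill) + 1 (decode) = 7
```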
-""" - -import textwrap -import time -from pathlib import Path -from typing import Optional, Tuple, Union - -import torch -from torch.nn.attention import SDPBackend - -from torchao._models.llama.model import Transformer -from torchao._models.llama.tokenizer import get_tokenizer -from torchao.prototype.profiler import ( - CUDADeviceSpec, - TransformerPerformanceCounter, - total_model_params, -) - -DEVICE_SPEC: CUDADeviceSpec -PERF_COUNTER: TransformerPerformanceCounter -PERF_COUNTER_PREFIX = "TransformerPerfCounter" -GPT_FAST_PREFIX = "GPTFast" -DELIMITER = "\n" + "=" * 30 + "\n" - - -def device_sync(device): - if "cuda" in device: - torch.cuda.synchronize(device) - elif ("cpu" in device) or ("mps" in device): - pass - else: - print(f"device={device} is not yet supported") - - -def multinomial_sample_one_no_sync( - probs_sort, -): # Does multinomial sampling without a cuda synchronization - q = torch.empty_like(probs_sort).exponential_(1) - return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) - - -def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): - logits = logits / max(temperature, 1e-5) - - if top_k is not None: - v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - pivot = v.select(-1, -1).unsqueeze(-1) - logits = torch.where(logits < pivot, -float("Inf"), logits) - probs = torch.nn.functional.softmax(logits, dim=-1) - return probs - - -def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): - probs = logits_to_probs(logits[0, -1], temperature, top_k) - idx_next = multinomial_sample_one_no_sync(probs) - return idx_next, probs - - -def prefill( - model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs -) -> torch.Tensor: - # input_pos: [B, S] - seqlen = input_pos.shape[-1] - num_tokens = input_pos.numel() - assert num_tokens == seqlen - - step_name = f"prefill_seqlen-{seqlen}".upper() - with PERF_COUNTER.count(step_name, num_tokens=num_tokens): - logits = model(x, input_pos) - next_token = sample(logits, **sampling_kwargs)[0] - print(DELIMITER) - stats_str = PERF_COUNTER.print_summary(labels=[step_name], show=False) - print(f"{PERF_COUNTER_PREFIX} Metrics\n{stats_str}") - - return next_token - - -def decode_one_token( - model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs -) -> Tuple[torch.Tensor, torch.Tensor]: - # input_pos: [B, 1] - context_len = input_pos[-1].item() - num_tokens = input_pos.numel() - assert input_pos.shape[-1] == 1 - assert num_tokens == 1 - - step_name = f"decode_ctx-{context_len}_num_toks-{num_tokens}".upper() - with PERF_COUNTER.count(step_name, num_tokens=num_tokens): - logits = model(x, input_pos) - next_token = sample(logits, **sampling_kwargs) - print(DELIMITER) - stats_str = PERF_COUNTER.print_summary(labels=[step_name], show=False) - print(f"{PERF_COUNTER_PREFIX} Metrics\n{stats_str}") - - return next_token - - -def decode_n_tokens( - model: Transformer, - cur_token: torch.Tensor, - input_pos: torch.Tensor, - num_new_tokens: int, - callback=lambda _: _, - **sampling_kwargs, -): - new_tokens, new_probs = [], [] - for i in range(num_new_tokens): - with torch.nn.attention.sdpa_kernel( - backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH] - ): # Actually better for Inductor to codegen attention here - next_token, next_prob = decode_one_token( - model, cur_token, input_pos, **sampling_kwargs - ) - input_pos += 1 - new_tokens.append(next_token.clone()) - callback(new_tokens[-1]) - new_probs.append(next_prob.clone()) - cur_token = 
next_token.view(1, -1) - - return new_tokens, new_probs - - -def model_forward(model, x, input_pos): - return model(x, input_pos) - - -@torch.no_grad() -def generate( - model: Transformer, - prompt: torch.Tensor, - max_new_tokens: int, - *, - callback=lambda x: x, - **sampling_kwargs, -) -> torch.Tensor: - """ - Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. - """ - # create an empty tensor of the expected final shape and fill in the current tokens - T = prompt.size(0) - T_new = T + max_new_tokens - max_seq_length = min(T_new, model.config.block_size) - - device, dtype = prompt.device, prompt.dtype - with torch.device(device): - model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) - - # create an empty tensor of the expected final shape and fill in the current tokens - empty = torch.empty(T_new, dtype=dtype, device=device) - empty[:T] = prompt - seq = empty - input_pos = torch.arange(0, T, device=device) - - next_token = prefill( - model, prompt.view(1, -1), input_pos, **sampling_kwargs - ).clone() - seq[T] = next_token - - input_pos = torch.tensor([T], device=device, dtype=torch.int) - - generated_tokens, _ = decode_n_tokens( - model, - next_token.view(1, -1), - input_pos, - max_new_tokens - 1, - callback=callback, - **sampling_kwargs, - ) - seq[T + 1 :] = torch.cat(generated_tokens) - - return seq - - -def encode_tokens(tokenizer, string, bos=True, device="cuda"): - tokens = tokenizer.encode(string) - if bos: - tokens = [tokenizer.bos_id()] + tokens - return torch.tensor(tokens, dtype=torch.int, device=device) - - -def _load_model(checkpoint_path, device, precision): - with torch.device("meta"): - model = Transformer.from_name(checkpoint_path.parent.name) - - checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) - if "model" in checkpoint and "stories" in str(checkpoint_path): - checkpoint = checkpoint["model"] - model.load_state_dict(checkpoint, assign=True) - - model = model.to(device=device, dtype=precision) - return model.eval() - - -def main( - prompt: str, - num_samples: int, - max_new_tokens: int, - top_k: int, - temperature: float, - checkpoint_path: Union[Path, str], - save_path: Union[Path, str], - device: str = "cuda", - precision: torch.dtype = torch.bfloat16, -) -> None: - """Generates text samples based on a pre-trained Transformer model and tokenizer.""" - assert checkpoint_path.is_file(), checkpoint_path - - tokenizer_path = checkpoint_path.parent / "tokenizer.model" - assert tokenizer_path.is_file(), str(tokenizer_path) - - print(f"{GPT_FAST_PREFIX}") - print("Loading model ...") - t0 = time.time() - model = _load_model(checkpoint_path, device, precision) - - device_sync(device=device) # MKG - print(f"Time to load model: {time.time() - t0:.02f} seconds") - - global DEVICE_SPEC - global PERF_COUNTER - - DEVICE_SPEC = CUDADeviceSpec(dtype=precision) - PERF_COUNTER = TransformerPerformanceCounter(depth=3, device_spec=DEVICE_SPEC) - print(DELIMITER) - print(f"{PERF_COUNTER_PREFIX}") - print(f"Using {DEVICE_SPEC}") - print(f"Model Config: {model.config}") - - num_active_params = total_model_params(model, exclude_embeddings=True) - num_params = total_model_params(model, exclude_embeddings=False) - model_size = num_params * precision.itemsize - print(f"Active params, Total Params: {num_active_params}, {num_params}") - - tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) - - encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) - prompt_length = encoded.size(0) - - 
torch.manual_seed(1234) - - aggregate_metrics = { - "tokens_per_sec": [], - } - - start = 0 - - for i in range(start, num_samples): - t0 = time.perf_counter() - - y = generate( - model, - encoded, - max_new_tokens, - temperature=temperature, - top_k=top_k, - ) - - t = time.perf_counter() - t0 - txt = tokenizer.decode(y.tolist()) - print(DELIMITER) - print(f"{GPT_FAST_PREFIX}") - print(f"Generated text for sample {i}: {txt}\n") - - tokens_generated = y.size(0) - prompt_length - tokens_sec = tokens_generated / t - sample_metrics = textwrap.dedent(f"""\ - {GPT_FAST_PREFIX} Sample Metrics - Time for inference {i+1}: {prompt_length} prompt tokens {tokens_generated} tokens generated, {t:.02f} sec total, {tokens_sec:.02f} tokens/sec - Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s""") - print( - textwrap.indent( - sample_metrics, - prefix=" ", - predicate=lambda line: not line.startswith(GPT_FAST_PREFIX), - ) - ) - aggregate_metrics["tokens_per_sec"].append(tokens_sec) - - # First print aggregate stats from original gpt-fast script - print(DELIMITER) - gpt_stats = textwrap.dedent(f"""\ - {GPT_FAST_PREFIX} Aggregate Stats - Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f} - Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB""") - - print( - textwrap.indent( - gpt_stats, - prefix=" ", - predicate=lambda line: not line.startswith(GPT_FAST_PREFIX), - ) - ) - - # Print performance summary from TransformerPerformanceCounter - print(DELIMITER) - total_stats_str = PERF_COUNTER.print_summary(show=False) - print(f"{PERF_COUNTER_PREFIX}\n{total_stats_str}") - print(f"\nSaving performance results to {save_path}") - PERF_COUNTER.to_json(save_path) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser( - description="TransformerPerformanceCounter Example", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - parser.add_argument( - "--prompt", type=str, default="Hello, my name is", help="Input prompt." - ) - parser.add_argument("--num_samples", type=int, default=1, help="Number of samples.") - parser.add_argument( - "--max_new_tokens", type=int, default=2, help="Maximum number of new tokens." - ) - parser.add_argument("--top_k", type=int, default=200, help="Top-k for sampling.") - parser.add_argument( - "--temperature", type=float, default=0.8, help="Temperature for sampling." 
- ) - parser.add_argument( - "--checkpoint_path", - type=Path, - default=Path("./checkpoints/7B/model.pth"), - help="Model checkpoint path.", - ) - parser.add_argument( - "--save_path", - type=Path, - default=Path("performance_stats.json"), - help="Path to save performance stats.", - ) - args = parser.parse_args() - main(**vars(args)) diff --git a/torchao/prototype/profiler/__init__.py b/torchao/prototype/profiler/__init__.py deleted file mode 100644 index 976d4e3a05..0000000000 --- a/torchao/prototype/profiler/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Re-exports -from .device_spec import CUDADeviceSpec, DeviceSpec -from .performance_counter import ( - CUDAPerformanceTimer, - PerformanceCounterMode, - PerformanceStats, - PerformanceTimer, - TransformerPerformanceCounter, -) -from .utils import total_model_params - -__all__ = [ - "CUDAPerformanceTimer", - "PerformanceCounterMode", - "PerformanceStats", - "PerformanceTimer", - "TransformerPerformanceCounter", - "CUDADeviceSpec", - "DeviceSpec", - "total_model_params", -] diff --git a/torchao/prototype/profiler/device_spec.py b/torchao/prototype/profiler/device_spec.py deleted file mode 100644 index 040367583f..0000000000 --- a/torchao/prototype/profiler/device_spec.py +++ /dev/null @@ -1,421 +0,0 @@ -from dataclasses import dataclass, field, fields -from typing import Dict, Optional, Union - -import torch - -"""This module contains the device specs for theoretical peak performance calculations. - -- Contains a list of available chips and their corresponding theoretical peak FLOPs performance for various torch.dtypes. -- Exposes a DeviceSpec interface and a concrete CUDADeviceSpec implementation for CUDA gpus. Extendable to other device types. -- Where possible, the CUDADeviceSpec auto-populates its fields by utilizing `torch.cuda` API and `triton.runtime.driver`. 
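As a quick orientation, a minimal sketch of how this module is consumed (mirroring the usage in `perf_profile.py` above; it assumes a CUDA machine with triton installed and a GPU that appears in the spec table below):

```python
import torch

from torchao.prototype.profiler import CUDADeviceSpec

# dtype selects the peak-FLOPs entry; name, bandwidth, and vram are auto-populated
# from torch.cuda / triton, and flops_per_s is looked up for recognized chips.
device_spec = CUDADeviceSpec(dtype=torch.bfloat16)

# Ridgepoint of the roofline (FLOP / byte): below this arithmetic intensity a
# kernel is memory-bound, above it compute-bound.
print(device_spec.roofline_balancepoint)
```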
- -""" -# Copied from https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/fabric/utilities/throughput.py -_AVAILABLE_GPU_SPECS: Dict[str, Dict[Union[str, torch.dtype], float]] = { - # Hopper - # source: https://resources.nvidia.com/en-us-tensor-core - "h100 nvl": { - torch.float64: 67e12, - torch.float32: 133.8e12, - "tfloat32": 989.4e12, - torch.bfloat16: 1978.8e12, - torch.float16: 1978.8e12, - torch.int8: 3957.8e12, - }, - "h100 sxm": { - torch.float64: 33.5e12, - torch.float32: 66.9e12, - "tfloat32": 494.7e12, - torch.bfloat16: 989.4e12, - torch.float16: 989.4e12, - torch.int8: 1978.9e12, - }, - "h100 pcie": { - torch.float64: 25.6e12, - torch.float32: 51.2e12, - "tfloat32": 378e12, - torch.bfloat16: 756e12, - torch.float16: 756e12, - torch.int8: 1513e12, - }, - # Ada - # source: https://images.nvidia.com/aem-dam/Solutions/Data-Center/l4/nvidia-ada-gpu-architecture-whitepaper-v2.1.pdf - "rtx 4090": { - torch.float32: 82.6e12, - "tfloat32": 82.6e12, - torch.bfloat16: 82.6e12, - torch.float16: 82.6e12, - torch.int8: 660.6e12, - "int4": 1321.2e12, - }, - "rtx 4080": { - torch.float32: 48.7e12, - "tfloat32": 48.7e12, - torch.bfloat16: 48.7e12, - torch.float16: 48.7e12, - torch.int8: 389.9e12, - "int4": 779.8e12, - }, - "l4": { - torch.float32: 30.3e12, - "tfloat32": 60e12, - torch.bfloat16: 121e12, - torch.float16: 121e12, - torch.int8: 242e12, - "int4": 484e12, - }, - "l40": { - torch.float32: 90.5e12, - "tfloat32": 90.5e12, - torch.bfloat16: 181e12, - torch.float16: 181e12, - torch.int8: 362e12, - "int4": 724e12, - }, - # Ampere - # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf - # sxm and pcie have same flop counts - "a100": { - torch.float64: 9.7e12, - torch.float32: 19.5e12, - "tfloat32": 156e12, - torch.bfloat16: 312e12, - torch.float16: 312e12, - torch.int8: 624e12, - }, - "a6000": { - torch.float32: 38.7e12, - "tfloat32": 77.4e12, - torch.bfloat16: 38.7e12, - torch.float16: 38.7e12, - torch.int8: 309.7e12, - "int4": 619.3e12, - }, - "a40": { - torch.float32: 37.4e12, - "tfloat32": 74.8e12, - torch.bfloat16: 37.4e12, - torch.float16: 37.4e12, - torch.int8: 299.3e12, - "int4": 598.7e12, - }, - # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/a10-datasheet.pdf - "a10g": { - torch.float32: 31.2e12, - "tfloat32": 62.5e12, - torch.bfloat16: 125e12, - torch.float16: 125e12, - torch.int8: 250e12, - "int4": 500e12, - }, - "rtx 3090 ti": { - torch.float32: 40e12, - "tfloat32": 40e12, - torch.bfloat16: 40e12, - torch.float16: 40e12, - torch.int8: 320e12, - "int4": 640e12, - }, - "rtx 3090": { - torch.float32: 35.6e12, - "tfloat32": 35.6e12, - torch.bfloat16: 35.6e12, - torch.float16: 35.6e12, - torch.int8: 284e12, - "int4": 568e12, - }, - "rtx 3080 ti": { - torch.float32: 34.1e12, - "tfloat32": 34.1e12, - torch.bfloat16: 34.1e12, - torch.float16: 34.1e12, - torch.int8: 272.8e12, - "int4": 546.6e12, - }, - "rtx 3080": { - torch.float32: 29.8e12, - "tfloat32": 29.8e12, - torch.bfloat16: 29.8e12, - torch.float16: 29.8e12, - torch.int8: 238e12, - "int4": 476e12, - }, - "rtx 3070": { - torch.float32: 20.3e12, - "tfloat32": 20.3e12, - torch.bfloat16: 20.3e12, - torch.float16: 20.3e12, - torch.int8: 162.6e12, - "int4": 325.2e12, - }, - # Turing - # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf - # sxm and pcie have same flop counts - "t4": { - torch.float32: 8.1e12, - torch.float16: 65e12, - 
torch.int8: 130e12, - "int4": 260e12, - }, - # https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/quadro-rtx-5000-data-sheet-us-nvidia-704120-r4-web.pdf - "quadro rtx 5000": { - torch.float32: 11.2e12, - torch.float16: 89.2e12, - }, - "rtx 2080 super": { - torch.float32: 11.2e12, - torch.float16: 22.3e12, - torch.int8: 178.4e12, - "int4": 356.8e12, - }, - "rtx 2080 ti": { - torch.float32: 14.2e12, - torch.float16: 28.5e12, - torch.int8: 227.7e12, - "int4": 455.4e12, - }, - "rtx 2080": { - torch.float32: 10.6e12, - torch.float16: 21.2e12, - torch.int8: 169.6e12, - "int4": 339.1e12, - }, - # https://www.nvidia.com/content/PDF/nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.pdf - "rtx 2070 super": { - torch.float32: 9.1e12, - torch.float16: 18.1e12, - torch.int8: 145e12, - "int4": 290e12, - }, - "titan rtx": { - torch.float32: 16.3e12, - torch.float16: 32.6e12, - torch.int8: 261e12, - "int4": 522e12, - }, - # Volta - # source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf - "v100 sxm": { - torch.float64: 7.8e12, - torch.float32: 15.7e12, - torch.float16: 125e12, - }, - "v100 pcie": { - torch.float64: 7e12, - torch.float32: 14e12, - torch.float16: 112e12, - }, - "v100s pcie": { - torch.float64: 8.2e12, - torch.float32: 16.4e12, - torch.float16: 130e12, - }, -} - - -# Adapted from https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/fabric/utilities/throughput.py -def get_chip_name(device: int = 0) -> str: - device_name = torch.cuda.get_device_name(device) - chip = device_name.lower() - - if "h100" in chip: - if "hbm3" in chip: - chip = "h100 sxm" - elif "nvl" in chip: - chip = "h100 nvl" - elif "pcie" in chip or "hbm2e" in chip: - chip = "h100 pcie" - elif "l4" in chip: - chip = "l40" if "tesla" in chip else "l4" - elif "geforce rtx" in chip: - number = chip.split(" ")[3] - extra = "" - if "super" in chip: - extra = " super" - elif "ti" in chip: - extra = " ti" - chip = f"rtx {number}{extra}" - elif "a6000" in chip: - chip = "a6000" - elif "a100" in chip: - chip = "a100" - elif "a40" in chip: - chip = "a40" - elif "a10g" in chip: - chip = "a10g" - elif "t4" in chip: - chip = "t4" - elif "quadro rtx 5000" in chip: - chip = "quadro rtx 5000" - elif "titan rtx" in chip: - chip = "titan rtx" - elif "v100-sxm" in chip: - chip = "v100 sxm" - elif "v100-pcie" in chip: - chip = "v100 pcie" - elif "v100s-pcie" in chip: - chip = "v100s pcie" - else: - chip = None - return chip - - -def get_vram(device: int = 0) -> int: - device_props = torch.cuda.get_device_properties(device) - return device_props.total_memory - - -def get_bandwidth(device: int = 0) -> int: - try: - from triton.testing import get_dram_gbps - - bandwidth = get_dram_gbps(device) * 1e9 - except ImportError: - print("Could not import triton to get DRAM Gbps. Please install triton") - bandwidth = None - return bandwidth - - -def get_flops_by_dtype(chip_name: str) -> dict[torch.dtype, float]: - return _AVAILABLE_GPU_SPECS.get(chip_name, None) - - -@dataclass -class DeviceSpec: - """ - Abstract device specs for theoretical peak performance calculations. 
- - Fields will be auto-populated in __post_init__ if not already specified - and if data is available - - bandwidth (bytes /s) - - flops_per_s (FLOP / s) - - vram (bytes) - - dtype (torch.dtype) dtype used for theoretical peak performance - - flops_by_dtype (dict[Union[torch.dtype, str], float]): mapping from dtype to FLOP / s - """ - - device_type: str - name: Optional[str] = None - bandwidth: Optional[int] = None - flops_per_s: Optional[int] = None - vram: Optional[int] = None - dtype: Optional[torch.dtype] = None - flops_by_dtype: dict = field(default_factory=dict) - - def _post_init_check(self): - assert ( - self.bandwidth is not None - ), "GPU bandwidth is None - please specify the bandwidth in GB/s in order to enable speed of light calculations" - assert ( - self.dtype is not None - ), "GPU dtype is None - please specify the dtype in order to enable speed of light calculations" - assert ( - self.flops_per_s is not None - ), "GPU flops_per_s is None - please specify the flops_per_s in FLOP/s in order to enable speed of light calculations" - self.flops_by_dtype.update({self.dtype: self.flops_per_s}) - - # Not needed for downstream calculations atm, no need to assert - if self.vram is None: - print("GPU vram is None - please specify the vram in bytes") - - def __setattr__(self, name, value): - # Check if the attribute is already defined - if name in {f.name for f in fields(self)}: - super().__setattr__(name, value) - else: - raise AttributeError( - f"Cannot add new attribute '{name}' to {self.__class__.__name__}" - ) - - def __str__(self): - if self.bandwidth is not None: - formatted_bw = f"{self.bandwidth / 1e9:,.1f}GB/s" - if self.flops_per_s is not None: - formatted_flops = f"{self.flops_per_s / 1e12:,.1f}TFLOPs" - if self.vram is not None: - formatted_vram = f"{self.vram / 1e9:,.1f}GB" - return f"DeviceSpec(device_type={self.device_type}, name={self.name}, dtype={self.dtype}, bandwidth={formatted_bw}, flops={formatted_flops}, vram={formatted_vram})" - - @property - def roofline_balancepoint(self): - """ - Arithmetic intensity (FLOP / byte) transition point from - memory-bound to compute-bound regime. - - This is the ridgepoint of the roofline curve. - """ - assert ( - self.bandwidth is not None - ), "Please set bandwidth in order to calculate roofline balancepoint" - assert ( - self.flops_per_s is not None - ), "Please set flops_per_s in order to calculate roofline balancepoint" - - return self.flops_per_s / self.bandwidth - - -@dataclass -class CUDADeviceSpec(DeviceSpec): - """ - CUDA specs for theoretical peak performance, conformant with DeviceSpec interface. - - Fields will be auto-populated in __post_init__ if not specified - and if data is available. - - See _AVAILABLE_GPU_SPECS for a list of available chip data. 
- - Fields and expected units: - - device (int): CUDA device index - - name (str): name of the device - - bandwidth (bytes /s): memory bandwidth in bytes / s - - flops_per_s (FLOP / s): FLOPs per second - - vram (bytes): VRAM in bytes - - dtype (torch.dtype): dtype used for theoretical peak performance - - flops_by_dtype (dict[Union[torch.dtype, str], float]): mapping from dtype to FLOP / s - - use_tensorcores (bool): whether to use tensorcores if dtype == torch.float32 - """ - - device_type: str = "cuda" - # Device index - device: Optional[int] = 0 - # Whether to use tfloat32 FLOPs for dtype == torch.float32 - # We assume that tensorcores will always be used for fp16, int8, and other sub-single precision dtypes - use_tensorcores: bool = True - - def __post_init__(self): - # Populate fields if not already populated - self.name = torch.cuda.get_device_name(self.device) - - # Memory bandwidth in bytes / s - if self.bandwidth is None: - self.bandwidth = get_bandwidth() - - # FLOPs / s - if self.flops_per_s is None: - chip_name = get_chip_name(self.device) - if chip_name is None: - print(f"No FLOPs data available for device name {self.name}") - else: - flops_by_dtype = get_flops_by_dtype(chip_name) - if flops_by_dtype is not None: - self.flops_by_dtype.update(flops_by_dtype) - - # Populate flops if not already populated - if flops_by_dtype is not None and self.dtype in flops_by_dtype: - self.flops_per_s = flops_by_dtype[self.dtype] - - if self.dtype == torch.float32: - use_tf32 = "tfloat32" in flops_by_dtype and self.use_tensorcores - - if use_tf32: - self.flops_per_s = flops_by_dtype["tfloat32"] - else: - print( - f"Could not find FLOPs for dtype {self.dtype} for device {self.name}" - ) - # Vram - if self.vram is None: - self.vram = get_vram() - - # Issue post check warnings - self._post_init_check() diff --git a/torchao/prototype/profiler/performance_counter.py b/torchao/prototype/profiler/performance_counter.py deleted file mode 100644 index d79625d55c..0000000000 --- a/torchao/prototype/profiler/performance_counter.py +++ /dev/null @@ -1,597 +0,0 @@ -import inspect -import json -import math -import textwrap -import time -import warnings -from collections import defaultdict -from contextlib import contextmanager -from copy import deepcopy -from dataclasses import asdict, dataclass -from functools import partial -from pathlib import Path -from typing import Any, Dict, Optional, Union - -import torch -from torch.utils._pytree import tree_map -from torch.utils.flop_counter import FlopCounterMode - -from .device_spec import DeviceSpec - -aten = torch.ops.aten - - -class DeviceInfoMissing(UserWarning): - pass - - -# Prevent excessive output -warnings.simplefilter("once", DeviceInfoMissing) - - -class PerformanceCounterMode(FlopCounterMode): - """ - ``PerformanceCounterMode`` extends FlopCounterMode to track IO in addition to flops. - - It does this using a ``TorchDispatchMode`` per `FlopCounterMode` and tracks the - inputs and outputs of each operator, organized by module. 
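A minimal usage sketch, mirroring how the tests earlier in this patch drive the mode (illustrative only; the methods listed next are the ones read out at the end):

```python
import torch

from torchao.prototype.profiler import PerformanceCounterMode

a = torch.randn(1024, 1024)
b = torch.randn(1024, 1024)

# Dispatch-mode context: every aten op executed inside is counted and attributed
# to the issuing module ("Global" when run outside an nn.Module).
with PerformanceCounterMode(display=False) as perf_counter:
    _ = a @ b

print(perf_counter.get_total_flops())          # total FLOPs
print(perf_counter.get_total_io())             # total bytes moved (args + kwargs + outputs)
print(perf_counter.get_summary_flop_counts())  # per-module FLOP totals
```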
- - In addition to the methods exposed by FlopCounterMode, the following methods are - available: - - ``get_io_counts``: returns a dictionary of module names and their associated IO counts by aten operator - - ``get_total_io``: returns the total number of IO operations across all modules - - ``get_summary_io_counts``: returns a summary of the IO counts for each module (totals by operator) - - ``get_summary_flop_counts``: returns a summary of the flop counts for each module (totals by operator) - """ - - def __init__(self, display=False, depth=10, debug=False): - self.debug = debug - self.io_counts = defaultdict(lambda: defaultdict(int)) - super().__init__(display=display, depth=depth) - - def get_io_counts(self): - return {k: dict(v) for k, v in self.io_counts.items()} - - def get_total_io(self): - return sum(self.io_counts["Global"].values()) - - def _get_io_sizes(self, args): - sizes = tree_map( - lambda x: x.numel() * x.element_size() - if isinstance(x, torch.Tensor) - else 0, - args, - ) - if not hasattr(sizes, "__len__"): - sizes = [sizes] - return sizes - - def get_summary_flop_counts(self): - flop_counts = self.get_flop_counts() - return {k: sum(v.values()) for k, v in flop_counts.items()} - - def get_summary_io_counts(self): - io_counts = self.get_io_counts() - return {k: sum(v.values()) for k, v in io_counts.items()} - - def _nearest_power_of_10(self, x): - if x == 0: - return x, 0 - - power = int(math.floor(math.log10(abs(x)) / 3)) - scaled_value = x / (10 ** (3 * power)) - - return scaled_value, power - - def pretty_summary_counts(self, type="flops", precision=2, depth=None): - assert type in ["flops", "io"] - metric_units = { - 0: "", - 1: "k", - 2: "M", - 3: "G", - 4: "T", - 5: "P", - 6: "E", - 7: "Z", - 8: "Y", - } - - if depth is None: - depth = self.depth - summary_counts = ( - self.get_summary_flop_counts() - if type == "flops" - else self.get_summary_io_counts() - ) - keys_to_print = [k for k in summary_counts.keys() if len(k.split(".")) <= depth] - units = "FLOPs" if type == "flops" else "B" - summary_str = [] - for k in sorted(keys_to_print, key=lambda x: len(x.split("."))): - if k == "Global" or k is None: - continue - spaces = " " * (len(k.split(".")) - 1) - scaled_val, power = self._nearest_power_of_10(summary_counts[k]) - formatted_val = f"{scaled_val:.{precision}f}{metric_units[power]}{units}" - summary_str.append(f"{spaces}{k}: {formatted_val}") - - return "\n".join(summary_str) - - def _count_io(self, func_packet, out, args, kwargs): - arg_sizes = self._get_io_sizes(args) - kwargs_sizes = self._get_io_sizes(kwargs.values()) - out_sizes = self._get_io_sizes(out) - arg_size, kwargs_size, out_size = ( - sum(arg_sizes), - sum(kwargs_sizes), - sum(out_sizes), - ) - return arg_size, kwargs_size, out_size - - def _count_flops(self, func_packet, out, args, kwargs): - if func_packet in self.flop_registry: - flop_count_func = self.flop_registry[func_packet] - flop_count = flop_count_func(*args, **kwargs, out_val=out) # type: ignore[operator] - arg_size, kwarg_size, out_size = self._count_io( - func_packet, out, args, kwargs - ) - total_size = arg_size + kwarg_size + out_size - - for par in set(self.mod_tracker.parents): - if self.debug: - print(f"Counting flops for {par}, {func_packet}: {flop_count}") - print( - f"Counting io for {par}, {func_packet}: {sum([arg_size, kwarg_size, out_size])} = {arg_size} + {kwarg_size} + {out_size}" - ) - self.flop_counts[par][func_packet] += flop_count - self.io_counts[par][func_packet] += total_size - - return out - - -class 
PerformanceTimer: - """ - Context manager that records the latency, io, and flops of a torch operator / module. - - Timing is done using `time.perf_counter` and can be overridden to use a different - timer (see `CUDAPerformanceTimer`). - - IO and FLOPs are recorded using `PerformanceCounterMode`. - - Available attributes: - name: str - precision: int - display: bool - depth (int): passed to `PerformanceCounterMode` if displaying and determines depth of module tree to display. - **Note**: these attributes are primarily used for debugging when using the `PerformanceTimer` standalone. - The TransformerPerformanceCounter class is a higher-level API that should be used instead. - - """ - - def __init__(self, name, precision=1, display=False, depth=10): - self.name = name - self.precision = precision - self.display = display - self.depth = depth - self.perf_counter = PerformanceCounterMode(display=display, depth=depth) - - def __enter__(self): - self.start = time.perf_counter() - self.perf_counter.__enter__() - return self - - def _print_exit_msg(self): - gflops = round(self.total_flops / 1e9, self.precision) - ms = round(self.latency * 1e3, self.precision) - if self.display: - print(f"{self.name.upper()}: latency = {ms} ms, FLOPS = {gflops} GFLOPs") - - def __exit__(self, type, value, traceback): - self.end = time.perf_counter() - # Convert to ms - self.latency = self.end - self.start - self.perf_counter.__exit__(type, value, traceback) - if self.display: - self._print_exit_msg() - - @property - def total_flops(self): - return self.perf_counter.get_total_flops() - - @property - def total_io(self): - return self.perf_counter.get_total_io() - - @property - def flops_table(self): - return self.perf_counter.get_table() - - def get_summary_flop_counts(self): - return self.perf_counter.get_summary_flop_counts() - - def get_summary_io_counts(self): - return self.perf_counter.get_summary_io_counts() - - @property - def flop_counts(self): - return self.perf_counter.get_flop_counts() - - @property - def io_counts(self): - return self.perf_counter.get_io_counts() - - def get_pretty_summary(self, depth): - return self.perf_counter.pretty_summary_counts( - depth=depth if depth is not None else self.depth - ) - - -class CUDAPerformanceTimer(PerformanceTimer): - """ - `PerformanceTimer` that uses `cudaEvents` to record latency. 
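A standalone usage sketch (assumes a CUDA device; as the note above says, `TransformerPerformanceCounter` below is the intended higher-level entry point):

```python
import torch

from torchao.prototype.profiler import CUDAPerformanceTimer

a = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)

# Latency comes from CUDA events; FLOPs and IO come from the embedded
# PerformanceCounterMode.
with CUDAPerformanceTimer(name="matmul") as timer:
    _ = a @ b

print(timer.latency, timer.total_flops, timer.total_io)
```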
- """ - - def __enter__(self): - self.start = torch.cuda.Event(enable_timing=True) - self.end = torch.cuda.Event(enable_timing=True) - self.start.record() - self.perf_counter = PerformanceCounterMode( - display=self.display, depth=self.depth - ) - self.perf_counter.__enter__() - return self - - def __exit__(self, type, value, traceback): - self.end.record() - torch.cuda.synchronize() - # Convert from ms to s - self.latency = self.start.elapsed_time(self.end) * 1e-3 - self.perf_counter.__exit__(type, value, traceback) - - if self.display: - self._print_exit_msg() - - -def to_nearest_power_of_10(x, precision=2): - # Dictionary mapping powers of 10 to their metric abbreviations - metric_units = {0: "", -6: "ยต", -3: "m", 6: "M", 9: "G", 12: "T"} - - # Determine the closest power of 10 - if x == 0: - return f"{x:.{precision}f}" - - power = int(math.floor(math.log10(abs(x)))) - # Adjust power to fit within the given metric units - powers = sorted(metric_units.keys()) - closest_power = min(powers, key=lambda p: abs(p - power)) - - # Calculate the value formatted to the closest power of 10 - value = x / 10**closest_power - - # Map the power to the metric unit - unit = metric_units.get(closest_power, f"e{closest_power}") - - return f"{value:,.{precision}f} {unit}" - - -class DictMixin: - """ - Enables dict-like interface to dataclasses. - """ - - def __getitem__(self, key): - if hasattr(self, key): - return getattr(self, key) - else: - raise KeyError(key) - - def __setitem__(self, key, value): - setattr(self, key, value) - - def __contains__(self, key): - return hasattr(self, key) - - def __iter__(self): - for key in self.__dict__: - yield key - - -def _get_property_methods(cls): - return [ - name for name, _ in inspect.getmembers(cls, lambda m: isinstance(m, property)) - ] - - -@dataclass -class PerformanceStats(DictMixin): - """ - Data struct that stores performance statistics. - - Attrs: - num_tokens (int): number of tokens processed - latency (float): latency in seconds - total_flops (int): total FLOPs - total_io (int): total data movement in bytes - flops_summary (Dict[str, int]): summary of FLOPs by module - io_summary (Dict[str, int]): summary of data movement in bytes by module - flop_counts (Dict[str, Dict[Any, int]]): FLOP counts by module and operation - io_counts (Dict[str, Dict[Any, int]]): data movement by module and operation - device_bandwidth (Optional[float]): device bandwidth in bytes per second - device_flops_per_s (Optional[float]): device FLOPs per second - - Additionally, the following derived properties are available: - token_throughput (float): number of tokens processed per second - achieved_flops_per_s (float): achieved FLOPs per second - achieved_bandwidth (float): achieved data movement in bytes per second - theoretical_io_latency (Optional[float]): theoretical I/O latency in seconds, set to None if - no device bandwidth is available. - theoretical_compute_latency (Optional[float]): theoretical compute latency in seconds, set to None if - no device FLOPs are available. 
- """ - - label: str - num_tokens: int - latency: float - total_flops: int - total_io: int - flops_summary: Dict[str, int] - io_summary: Dict[str, int] - flop_counts: Dict[str, Dict[Any, int]] - io_counts: Dict[str, Dict[Any, int]] - device_bandwidth: Optional[float] = None - device_flops_per_s: Optional[float] = None - - @property - def token_throughput(self): - return self.num_tokens / self.latency - - @property - def achieved_flops_per_s(self): - return self.total_flops / self.latency - - @property - def achieved_bandwidth(self): - return self.total_io / self.latency - - @property - def theoretical_io_latency(self): - if self.device_bandwidth is not None: - return self.total_io / self.device_bandwidth - else: - warnings.warn( - "Device bandwidth is not specified. Please specify the device bandwidth to enable io latency calculation" - ) - return None - - @property - def theoretical_compute_latency(self): - if self.device_flops_per_s is not None: - return self.total_flops / self.device_flops_per_s - else: - warnings.warn( - "Device flops_per_s is not specified. Please specify the device throughput to enable compute latency calculation" - ) - return None - - @property - def bandwidth_utilization(self): - if self.device_bandwidth is not None: - return self.achieved_bandwidth / self.device_bandwidth - else: - warnings.warn( - "Device bandwidth is not specified. Please specify the device bandwidth to enable bandwidth utilization calculation" - ) - return None - - @property - def flops_utilization(self): - if self.device_flops_per_s is not None: - return self.achieved_flops_per_s / self.device_flops_per_s - else: - warnings.warn( - "Device flops_per_s is not specified. Please specify the device throughput to enable flops utilization calculation" - ) - return None - - def _format(self, value, suffix, precision=2, round=True): - if round: - return to_nearest_power_of_10(value, precision=precision) + suffix - return f"{value:.{precision}f} " + suffix - - def __str__(self): - txt = textwrap.dedent(f"""\ - {self.label}: - Latency = {self._format(self.latency, "s")} - Tokens - Total: {self.num_tokens} tokens - Throughput: {self.token_throughput:,.0f} tokens/s - IO - Total: {self._format(self.total_io, "B")} - Throughput: {self._format(self.achieved_bandwidth, "B/s")} - Theoretical Latency: {self._format(self.theoretical_io_latency, "s") if self.theoretical_io_latency is not None else "N/A"} - FLOPs - Total: {self._format(self.total_flops, "FLOPs")} - Throughput: {self._format(self.achieved_flops_per_s, "FLOPs/s")} - Theoretical Latency: {self._format(self.theoretical_compute_latency, "s") if self.theoretical_compute_latency is not None else "N/A"} - Utilization - Bandwidth: {self._format(self.bandwidth_utilization, round=False, precision=4, suffix="%") if self.bandwidth_utilization is not None else "N/A"} - FLOPs: {self._format(self.flops_utilization, round=False, precision=4, suffix="%") if self.flops_utilization is not None else "N/A"}""") - - return txt - - def to_dict(self): - d = asdict(self) - # Update dict with properties - props = _get_property_methods(self.__class__) - d.update({prop: getattr(self, prop) for prop in props}) - - return d - - -class TransformerPerformanceCounter: - """ - Context manager-like class for tracking performance across multiple calls - to a Transformer model. - - Provides properties for accessing performance stats for data movement and FLOPs for each context as well as - summary stats across all contexts. 
- Additionally, if a device_spec is provided, theoretical peak bandwidth / FLOPs stats will be available. - - See `PerformanceStats` struct for description of tracked metrics. - - Example: - >>> manager = TransformerPerformanceCounter(device_spec=device_spec) - >>> with manager.count(label="prefill", num_tokens=x.numel()): - >>> out = model(encoded_prompt) - >>> manager.print_summary(labels=["prefill"]) # prints recorded stats for "prefill" context - >>> with manager.count(label="decode", num_tokens=1): - >>> out = model(out[-1]) - >>> manager.print_summary(labels=["decode"]) # prints recorded stats for "decode" context - >>> print(manager.print_summary) # prints accumulated stats across all contexts - """ - - def __init__( - self, - depth=10, - timer_cls: PerformanceTimer = PerformanceTimer, - device_spec: DeviceSpec = None, - ): - super().__init__() - self._counts: Dict[str, PerformanceStats] = {} - self._depth = depth - self.timer_cls = timer_cls - self.device_spec = device_spec - - @contextmanager - def count(self, label: str, num_tokens: int): - perf_timer = self.timer_cls(name=label, depth=self._depth) - perf_timer.__enter__() - try: - yield self - finally: - perf_timer.__exit__(None, None, None) - stats = PerformanceStats( - label=label, - num_tokens=num_tokens, - latency=perf_timer.latency, - total_flops=perf_timer.total_flops, - total_io=perf_timer.total_io, - flops_summary=perf_timer.get_summary_flop_counts(), - io_summary=perf_timer.get_summary_io_counts(), - flop_counts=perf_timer.flop_counts, - io_counts=perf_timer.io_counts, - device_bandwidth=self.device_spec.bandwidth - if self.device_spec is not None - else None, - device_flops_per_s=self.device_spec.flops_per_s - if self.device_spec is not None - else None, - ) - self._counts[label] = stats - - @property - def counts(self): - return self._counts - - def get_counts(self): - return self._counts - - @property - def total_flops(self): - return sum(count.total_flops for count in self._counts.values()) - - @property - def total_io(self): - return sum(count.total_io for count in self._counts.values()) - - @property - def total_tokens(self): - return sum(count.num_tokens for count in self._counts.values()) - - @property - def total_time(self): - return sum(count.latency for count in self._counts.values()) - - def _summarize_stat(self, key): - return { - label: getattr(self._counts[label], key) for label in self._counts.keys() - } - - @property - def flops_summary(self): - return self._summarize_stat(key="flops_summary") - - @property - def io_summary(self): - return self._summarize_stat(key="io_summary") - - @property - def flop_counts_summary(self): - return self._summarize_stat(key="flop_counts") - - @property - def io_counts_summary(self): - return self._summarize_stat(key="io_counts") - - @property - def stats_summary(self): - stats = PerformanceStats( - label="Performance Summary", - num_tokens=self.total_tokens, - latency=self.total_time, - total_flops=self.total_flops, - total_io=self.total_io, - flops_summary=self.flops_summary, - io_summary=self.io_summary, - flop_counts=self.flop_counts_summary, - io_counts=self.io_counts_summary, - device_bandwidth=self.device_spec.bandwidth - if self.device_spec is not None - else None, - device_flops_per_s=self.device_spec.flops_per_s - if self.device_spec is not None - else None, - ) - - return stats - - def print_summary(self, labels: list[str] = None, show: bool = False): - _print = partial(print, flush=True, end="\n") - # Delegate to __str__ of PerformanceStats for pretty 
printing - if labels is None: - text = str(self.stats_summary) - if show: - _print(text) - return text - else: - txts = [] - for label in labels: - text = str(self._counts[label]) - if show: - _print(text) - txts.append(text) - return "\n".join(txts) - - def to_dict(self): - # Convert flop_counts from OpOverloadPackets to str - # Then delegate to PerformanceStats `to_dict`, which updates with derived metrics (property methods) - counts = deepcopy(self._counts) - for label, label_counts in counts.items(): - counts[label]["flop_counts"] = { - mod: {str(op): count for op, count in op_count.items()} - for mod, op_count in label_counts["flop_counts"].items() - } - counts[label]["io_counts"] = { - mod: {str(op): count for op, count in op_count.items()} - for mod, op_count in label_counts["io_counts"].items() - } - counts[label] = counts[label].to_dict() - - return counts - - def to_json(self, path: Union[str, Path] = None): - d = self.to_dict() - if path: - with open(path, "w") as f: - f.write(json.dumps(d, indent=2)) - return d diff --git a/torchao/prototype/profiler/utils.py b/torchao/prototype/profiler/utils.py deleted file mode 100644 index 9276dd37b1..0000000000 --- a/torchao/prototype/profiler/utils.py +++ /dev/null @@ -1,44 +0,0 @@ -import inspect - -import torch - -_HUGGINGFACE_CAUSAL_LM_BASE_CLASSES = [ - "causallm", - "pretrainedmodel", - "generationmixin", -] - - -def _get_all_base_classes(object): - return [cls.__name__.lower() for cls in inspect.getmro(object.__class__)] - - -def total_model_params( - model: torch.nn.Module, - exclude_embeddings: bool = True, - embedding_key: str = "tok_embeddings", -) -> int: - """ - Calculate total params of a HuggingFace CausalLM model or gpt-fast model - """ - num_params = sum(p.numel() for p in model.parameters()) - - # Exclude embeddings when calculating FLOP since they don't contribute to FLOP count - if exclude_embeddings: - # Not the cleanest, but check if any base class of the model is in _HUGGINGFACE_CAUSAL_LM_BASE_CLASSES - if ( - len( - set(_get_all_base_classes(model)).intersection( - _HUGGINGFACE_CAUSAL_LM_BASE_CLASSES - ) - ) - > 0 - ): - num_params -= model.model.embed_tokens.weight.numel() - elif hasattr(model, embedding_key): - num_params -= getattr(model, embedding_key).weight.numel() - else: - raise ValueError( - f"Could not find embedding in model {model.__class__.__name__}, please specify embedding attribute key" - ) - return num_params From 2d318064ca247d70c1c1f44994460d81b8368270 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 10 Mar 2025 17:06:29 -0700 Subject: [PATCH 2/2] Remove prototype/dtypes folder --- test/dtypes/test_bitnet.py | 91 ------ test/dtypes/test_uint2.py | 40 --- test/prototype/test_bitpacking_gen.py | 41 --- torchao/prototype/dtypes/__init__.py | 7 - torchao/prototype/dtypes/bitnet.py | 200 ------------- torchao/prototype/dtypes/uint2.py | 280 ------------------- torchao/prototype/dtypes/uintgen.py | 385 -------------------------- 7 files changed, 1044 deletions(-) delete mode 100644 test/dtypes/test_bitnet.py delete mode 100644 test/dtypes/test_uint2.py delete mode 100644 test/prototype/test_bitpacking_gen.py delete mode 100644 torchao/prototype/dtypes/__init__.py delete mode 100644 torchao/prototype/dtypes/bitnet.py delete mode 100644 torchao/prototype/dtypes/uint2.py delete mode 100644 torchao/prototype/dtypes/uintgen.py diff --git a/test/dtypes/test_bitnet.py b/test/dtypes/test_bitnet.py deleted file mode 100644 index e248b04b05..0000000000 --- a/test/dtypes/test_bitnet.py +++ /dev/null @@ 
-1,91 +0,0 @@ -import pytest -import torch -import torch.nn as nn - -from torchao.prototype.dtypes import BitnetTensor -from torchao.prototype.dtypes.uint2 import unpack_uint2 -from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_4: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - - -@pytest.fixture(autouse=True) -def run_before_and_after_tests(): - # source: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test # noqa: E501 - - # setup (currently do nothing) - - # tests will run here - yield - - # teardown - # avoid dynamo cache limit issues - torch._dynamo.reset() - - -@pytest.fixture -def bitnet_tensor(): - input_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8) - return BitnetTensor.from_unpacked(input_tensor) - - -def test_copy(bitnet_tensor): - copied_tensor = bitnet_tensor.clone() - assert torch.equal(bitnet_tensor.elem, copied_tensor.elem) - - -def test_transpose(bitnet_tensor): - transposed_tensor = bitnet_tensor.t() - expected_tensor = unpack_uint2(bitnet_tensor.elem).t() - assert torch.equal(unpack_uint2(transposed_tensor.elem), expected_tensor) - - -def test_multiply(bitnet_tensor): - w_t = torch.randint(0, 15, (4, 16), dtype=torch.uint8) - w = BitnetTensor.from_unpacked(w_t) - torch.addmm(torch.Tensor([1]), bitnet_tensor, w) - - -@pytest.mark.parametrize( - "dtype", - [torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64], -) -def test_conversion(bitnet_tensor, dtype): - converted_tensor = bitnet_tensor.to(dtype) - expected_tensor = unpack_uint2(bitnet_tensor.elem).to(dtype) - assert torch.allclose(converted_tensor, expected_tensor, atol=1e-5) - - -def _apply_weight_only_uint2_quant(model): - def fn(mod): - mod.weight = torch.nn.Parameter( - BitnetTensor.from_float(mod.weight), requires_grad=False - ) - return mod - - _replace_with_custom_fn_if_matches_filter( - model, - lambda mod: fn(mod), - lambda mod, fqn: isinstance(mod, torch.nn.Linear), - ) - - -@pytest.mark.skipif( - TORCH_VERSION_AT_LEAST_2_5, reason="Regression introdued in nightlies" -) -@pytest.mark.parametrize("input_shape", [[2, 4], [5, 5, 5, 4], [1, 4, 4]]) -def test_uint2_quant(input_shape): - device = "cuda" if torch.cuda.is_available() else "cpu" - x = torch.randn(*input_shape).to(device) - m = nn.Sequential(nn.Linear(4, 16)).to(device) - y_ref = m(x) - _apply_weight_only_uint2_quant(m) - y_wo = m(x) - assert y_ref.shape == y_wo.shape - torch.compile(m, fullgraph=True)(x) - - -if __name__ == "__main__": - pytest.main(__file__) diff --git a/test/dtypes/test_uint2.py b/test/dtypes/test_uint2.py deleted file mode 100644 index f6faaea10d..0000000000 --- a/test/dtypes/test_uint2.py +++ /dev/null @@ -1,40 +0,0 @@ -import pytest -import torch - -from torchao.prototype.dtypes import UInt2Tensor -from torchao.prototype.dtypes.uint2 import unpack_uint2 -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 - -if not TORCH_VERSION_AT_LEAST_2_4: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - - -@pytest.fixture -def uint2_tensor(): - input_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8) - return UInt2Tensor(input_tensor) - - -def test_copy(uint2_tensor): - copied_tensor = uint2_tensor.clone() - assert torch.equal(uint2_tensor.elem, copied_tensor.elem) - - -def test_transpose(uint2_tensor): - transposed_tensor = uint2_tensor.t() - expected_tensor = 
unpack_uint2(uint2_tensor.elem).t() - assert torch.equal(unpack_uint2(transposed_tensor.elem), expected_tensor) - - -@pytest.mark.parametrize( - "dtype", - [torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64], -) -def test_conversion(uint2_tensor, dtype): - converted_tensor = uint2_tensor.to(dtype) - expected_tensor = unpack_uint2(uint2_tensor.elem).to(dtype) - assert torch.allclose(converted_tensor, expected_tensor, atol=1e-5) - - -if __name__ == "__main__": - pytest.main(__file__) diff --git a/test/prototype/test_bitpacking_gen.py b/test/prototype/test_bitpacking_gen.py deleted file mode 100644 index 288ac1e4fc..0000000000 --- a/test/prototype/test_bitpacking_gen.py +++ /dev/null @@ -1,41 +0,0 @@ -import pytest -import torch - -from torchao.prototype.dtypes.uintgen import ( - pack_uint2, - pack_uint3, - pack_uint4, - pack_uint5, - pack_uint6, - pack_uint7, - unpack_uint2, - unpack_uint3, - unpack_uint4, - unpack_uint5, - unpack_uint6, - unpack_uint7, -) - - -@pytest.mark.parametrize( - "pack_fn, unpack_fn, bit_count", - [ - (pack_uint2, unpack_uint2, 2), - (pack_uint3, unpack_uint3, 3), - (pack_uint4, unpack_uint4, 4), - (pack_uint5, unpack_uint5, 5), - (pack_uint6, unpack_uint6, 6), - (pack_uint7, unpack_uint7, 7), - ], -) -def test_uint_packing(pack_fn, unpack_fn, bit_count): - x = torch.arange(0, 256, dtype=torch.uint8) - y = pack_fn(x) - z = unpack_fn(y) - k = z.view(-1, 2**bit_count) - check = torch.arange(0, 2**bit_count, dtype=torch.uint8).repeat(k.size(0), 1) - assert torch.all(k == check), f"Failed for {bit_count}-bit packing" - - -if __name__ == "__main__": - pytest.main(__file__) diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py deleted file mode 100644 index 9393737aff..0000000000 --- a/torchao/prototype/dtypes/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .bitnet import BitnetTensor -from .uint2 import UInt2Tensor - -__all__ = [ - "BitnetTensor", - "UInt2Tensor", -] diff --git a/torchao/prototype/dtypes/bitnet.py b/torchao/prototype/dtypes/bitnet.py deleted file mode 100644 index 72b444acd4..0000000000 --- a/torchao/prototype/dtypes/bitnet.py +++ /dev/null @@ -1,200 +0,0 @@ -import torch - -from torchao.prototype.dtypes.uint2 import UInt2Tensor, pack_uint2, unpack_uint2 - -BITNET_OPS_TABLE = {} - - -def implements(aten_ops): - def decorator(fn): - for op in aten_ops: - BITNET_OPS_TABLE[op] = fn - return fn - - return decorator - - -def _quantize_int2(x: torch.Tensor) -> torch.Tensor: - # Quantize the input tensor to int2 - quant = x.sign() + 1 - quant = BitnetTensor.from_unpacked(quant.to(torch.uint8)) - return quant - - -class BitnetTensor(UInt2Tensor): - def __new__(cls, input_tensor: torch.Tensor, **kwargs): - return super(BitnetTensor, cls).__new__(cls, input_tensor, **kwargs) - - def __init__(self, input_tensor: torch.Tensor, **kwargs): - super(BitnetTensor, self).__init__(input_tensor, **kwargs) - - @staticmethod - def __tensor_unflatten__(flattened, *meta): - # TODO - meta is not None, is it ok? 
- elem = flattened["elem"] - return BitnetTensor(elem) - - @classmethod - def from_unpacked(cls, unpacked: torch.Tensor) -> "BitnetTensor": - return cls(pack_uint2(unpacked)) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - def allowed_subclasses(type): - return ( - issubclass(cls, type) - or issubclass(torch._subclasses.fake_tensor.FakeTensor, type) - or issubclass( - torch._subclasses.functional_tensor.FunctionalTensor, type - ) - ) - - if not all(allowed_subclasses(t) for t in types): - return NotImplemented("Bitnet, Up to the next one to handle") - - if func in BITNET_OPS_TABLE: - return BITNET_OPS_TABLE[func](func, args, kwargs) - raise NotImplementedError( - f"Bitnet dispatch: attempting to run {func}, this is not supported" - ) - - @classmethod - def from_float(cls, w: torch.Tensor): - w_intq = _quantize_int2(w) - w_int2 = w_intq.to(device=w.device) - return w_int2 - - def clone(self): - return BitnetTensor(self.elem.clone()) - - def copy_(self, src): - self.elem.copy_(src.elem) - return self - - def tolist(self): - data = unpack_uint2(self.elem).tolist() - return data - - def __repr__(self): - try: - data = unpack_uint2(self.elem).tolist() - except AssertionError: - data = f"Tensor of shape {self.shape} and dtype {self.elem.dtype}" - return f"BitnetTensor({data}, dtype={self.elem.dtype})" - - def to(self, *args, **kwargs): - if len(args) == 1 and isinstance(args[0], torch.dtype): - dtype = args[0] - if dtype == torch.int8: - return unpack_uint2(self.elem).view(self.shape).view(torch.int8) - elif dtype in ( - torch.float, - torch.float16, - torch.bfloat16, - torch.int16, - torch.int32, - torch.int64, - ): - return unpack_uint2(self.elem).to(torch.int8).to(dtype) - elif dtype == torch.uint8: - return unpack_uint2(self.elem).view(torch.uint8) - elif isinstance(self, BitnetTensor): - return self - if "device" in kwargs: - device = kwargs["device"] - return BitnetTensor(self.elem.to(device=device)) - - return super().to(*args, **kwargs) - - -@implements([torch.ops.aten.mm.default]) -def mm(func, args, kwargs): - x, weight = args - if isinstance(x, BitnetTensor): - x = unpack_uint2(x.elem).to(torch.float32) - if isinstance(weight, BitnetTensor): - weight = unpack_uint2(weight.elem).to(torch.float32) - y = torch.mm(x, weight) - return y - - -@implements([torch.ops.aten.addmm.default]) -def addmm(func, args, kwargs): - bias, x, weight = args - if isinstance(x, BitnetTensor): - x = unpack_uint2(x.elem).to(torch.float32) - if isinstance(weight, BitnetTensor): - weight = unpack_uint2(weight.elem).to(torch.float32) - if bias is not None: - bias = bias.to(torch.float32) - y = torch.addmm(bias, x, weight) - return y - - -@implements([torch.ops.aten.t.default]) -def t(func, args, kwargs): - (tensor,) = args - unpacked = unpack_uint2(tensor.elem).to(tensor.device) - transposed = unpacked.t() - return BitnetTensor(pack_uint2(transposed)) - - -@implements([torch.ops.aten.detach.default]) -def detach(func, args, kwargs): - (tensor,) = args - return tensor - - -@implements([torch.ops.aten.to.dtype]) -def to_dtype(func, args, kwargs): - (tensor, dtype) = args - if dtype == torch.int8: - return unpack_uint2(tensor.elem).view(torch.uint8) - 1 - elif dtype in ( - torch.float, - torch.float16, - torch.bfloat16, - torch.int16, - torch.int32, - torch.int64, - ): - return unpack_uint2(tensor.elem).to(torch.int8).to(dtype) - elif dtype == torch.uint8: - return unpack_uint2(tensor.elem).view(torch.uint8) - elif isinstance(tensor, BitnetTensor): - return tensor.elem - raise 
NotImplementedError(f"to {dtype} not supported") - - -@implements([torch.ops.aten._to_copy.default]) -def _to_copy(func, args, kwargs): - (tensor,) = args - dtype = kwargs["dtype"] - if dtype == torch.int8: - return BitnetTensor( - unpack_uint2(tensor).view(tensor.shape).view(torch.int8) - 1 - ) - elif dtype in ( - torch.float, - torch.float16, - torch.bfloat16, - torch.int16, - torch.int32, - torch.int64, - ): - return BitnetTensor(tensor.to(torch.int8).to(dtype)) - elif isinstance(tensor, BitnetTensor): - return BitnetTensor(tensor) - raise NotImplementedError(f"to {dtype} not supported") - - -@implements([torch.ops.aten.clone.default]) -def clone(func, args, kwargs): - (tensor,) = args - return tensor.clone() - - -@implements([torch.ops.aten.allclose.default]) -def allclose(func, args, kwargs): - (a, b) = args - return torch.allclose(a.elem, b.elem, **kwargs) diff --git a/torchao/prototype/dtypes/uint2.py b/torchao/prototype/dtypes/uint2.py deleted file mode 100644 index d54e541751..0000000000 --- a/torchao/prototype/dtypes/uint2.py +++ /dev/null @@ -1,280 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Dict, Tuple - -import torch -import torch._prims_common as utils - -from torchao.utils import fill_defaults - -UINT2_OPS_TABLE: Dict[Any, Any] = {} - - -def implements(aten_ops): - def decorator(fn): - for op in aten_ops: - UINT2_OPS_TABLE[op] = fn - return fn - - return decorator - - -def down_size(size): - assert size[-1] % 4 == 0, f"{size} last dim not divisible by 4" - return (*size[:-1], size[-1] // 4) - - -def up_size(size): - return (*size[:-1], size[-1] * 4) - - -def unpack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: - shape = uint8_data.shape - uint8_data = uint8_data.to(torch.uint8) - first_elements = (uint8_data >> 6) & 0b11 - second_elements = (uint8_data >> 4) & 0b11 - third_elements = (uint8_data >> 2) & 0b11 - fourth_elements = uint8_data & 0b11 - return torch.stack( - (first_elements, second_elements, third_elements, fourth_elements), dim=-1 - ).view(up_size(shape)) - - -def pack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: - shape = uint8_data.shape - assert shape[-1] % 4 == 0, f"{shape}, last dim not divisible by 4" - uint8_data = uint8_data.contiguous().view(-1) - packed_data = ( - uint8_data[::4] << 6 - | uint8_data[1::4] << 4 - | uint8_data[2::4] << 2 - | uint8_data[3::4] - ).view(down_size(shape)) - return packed_data - - -@dataclass -class SubclassTensorArgs: - original_shape: torch.Size - original_strides: Tuple - storage_offset: int - dtype: torch.dtype - device: torch.device - requires_grad: bool - - -class UInt2Tensor(torch.Tensor): - def __new__(cls, input_tensor: torch.Tensor): - assert input_tensor.dtype == torch.uint8 - tensor_meta = SubclassTensorArgs( - input_tensor.size(), - input_tensor.stride(), - input_tensor.storage_offset(), - cls, - input_tensor.device, - input_tensor.requires_grad, - ) - uint2i_tensor = torch.Tensor._make_wrapper_subclass( - cls, - up_size(tensor_meta.original_shape), - tensor_meta.original_strides, - tensor_meta.storage_offset, - dtype=torch.uint8, # Not sure if this is correct - device=tensor_meta.device, - requires_grad=tensor_meta.requires_grad, - ) - return uint2i_tensor - - def __init__(self, input_tensor: torch.Tensor, **kwargs): - self.elem = input_tensor - - @classmethod - def from_packed(cls, unpacked): - return UInt2Tensor(pack_uint2(unpacked)) - - def tolist(self): - return unpack_uint2(self.elem).tolist() - - def __tensor_flatten__(self): - return ["elem"], None - - @staticmethod - def 
__tensor_unflatten__(flattened, meta): - assert meta is None - elem = flattened["elem"] - return UInt2Tensor(elem) - - def __hash__(self): - return hash(self.elem) - - def __eq__(self, other): - return torch.equal(self.elem, other.elem) - - def __repr__(self): - data = unpack_uint2(self.elem).tolist() - return f"UInt2Tensor({data}, dtype=torch.uint2)" - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - def allowed_subclasses(type): - return ( - issubclass(cls, type) - or issubclass(torch._subclasses.fake_tensor.FakeTensor, type) - or issubclass( - torch._subclasses.functional_tensor.FunctionalTensor, type - ) - ) - - if not all(allowed_subclasses(t) for t in types): - return NotImplemented("Up to the next one to handle") - - if func in UINT2_OPS_TABLE: - return UINT2_OPS_TABLE[func](func, args, kwargs) - raise NotImplementedError( - f"UINT2 dispatch: attempting to run {func}, this is not supported" - ) - - -@implements([torch.ops.aten.view.default]) -def uint2_view(func, args, kwargs): - tensor, size = args - size = utils.infer_size(size, tensor.numel()) - assert not kwargs - dsize = down_size(size) - reshaped_elem = tensor.elem.view(dsize) - return UInt2Tensor(reshaped_elem) - - -@implements([torch.ops.aten.view.dtype]) -def view_dtype(func, args, kwargs): - tensor, dtype = args - if dtype is torch.uint8: - return unpack_uint2(tensor.elem).to(torch.uint8) - raise NotImplementedError(f"view {dtype} not supported") - - -@implements([torch.ops.aten.clone.default]) -def clone(func, args, kwargs): - tensor = args[0] - return UInt2Tensor(tensor.elem.clone()) - - -@implements([torch.ops.aten._unsafe_view.default]) -def unsafe_view(func, args, kwargs): - tensor, size = args - size = utils.infer_size(size, tensor.numel()) - assert not kwargs - dsize = down_size(size) - reshaped_elem = tensor.elem.view(dsize) - return UInt2Tensor(reshaped_elem) - - -@implements([torch.ops.aten.unbind.int]) -def unbind(func, args, kwargs): - tensor, dim = fill_defaults(args, 2, [0]) - if dim != tensor.dim() - 1: - raise NotImplementedError(f"unbind dim={dim}") - else: - x = tensor.elem.to(torch.uint8).unbind(dim) - return x - - -@implements([torch.ops.aten._to_copy.default]) -def to_copy(func, args, kwargs): - (tensor,) = args - dtype = kwargs["dtype"] - if dtype == torch.uint8: - return unpack_uint2(tensor.elem).view(tensor.shape).view(torch.uint8) - elif dtype in ( - torch.float, - torch.float16, - torch.bfloat16, - torch.int16, - torch.int32, - torch.int64, - ): - return tensor.to(torch.uint8).to(dtype) - elif isinstance(tensor, UInt2Tensor): - return tensor - raise NotImplementedError(f"to_copy {dtype} not supported") - - -@implements([torch.ops.aten.select.int]) -def select(func, args, kwargs): - tensor, dim, index = args - if dim != tensor.dim() - 1: - selected_elem = tensor.elem.select(dim, index) - return UInt2Tensor(selected_elem) - else: - raise NotImplementedError(f"select dim={dim}") - - -@implements([torch.ops.aten.reshape.default]) -def reshape(func, args, kwargs): - tensor, size = args - size = utils.infer_size(size, tensor.numel()) - assert not kwargs - dsize = down_size(size) - reshaped_elem = tensor.elem.view(dsize) - return UInt2Tensor(reshaped_elem) - - -def slice_tensor(func, args, kwargs): - tensor, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - if dim == tensor.dim() - 1: - if step != 1: - raise NotImplementedError(f"slice step={step}") - assert start % 4 == 0, start - assert end is None or end % 4 == 0, end - end = end if end is not None 
else tensor.shape[dim] - sliced_elem = tensor.elem[..., start // 4 : end // 4 : step] - return UInt2Tensor(sliced_elem) - else: - sliced_elem = tensor.elem[..., start:end:step] - return UInt2Tensor(sliced_elem) - - -@implements([torch.ops.aten.equal.default]) -def equal(func, args, kwargs): - tensor, other = args - return torch.equal(tensor.elem, other.elem) - - -@implements([torch.ops.aten.detach.default]) -def detach(func, args, kwargs): - (tensor,) = args - detached_elem = tensor.elem.detach() - return UInt2Tensor(detached_elem) - - -@implements([torch.ops.aten.to.dtype]) -def to_dtype(func, args, kwargs): - (tensor, dtype) = args - if dtype == torch.uint8: - return unpack_uint2(tensor.elem).view(torch.uint8) - elif dtype in ( - torch.float, - torch.float16, - torch.bfloat16, - torch.int16, - torch.int32, - torch.int64, - ): - return unpack_uint2(tensor.elem).to(torch.uint8).to(dtype) - elif isinstance(tensor, UInt2Tensor): - return tensor.elem - - raise NotImplementedError(f"to {dtype} not supported") - - -@implements([torch.ops.aten.t.default]) -def t(func, args, kwargs): - (tensor,) = args - unpacked = unpack_uint2(tensor.elem).to(tensor.device) - transposed = unpacked.t() - return UInt2Tensor(pack_uint2(transposed)) - - -@implements([torch.ops.aten.allclose.default]) -def allclose(func, args, kwargs): - tensor, other = args - return torch.allclose(tensor.elem, other.elem) diff --git a/torchao/prototype/dtypes/uintgen.py b/torchao/prototype/dtypes/uintgen.py deleted file mode 100644 index 192e4ad05a..0000000000 --- a/torchao/prototype/dtypes/uintgen.py +++ /dev/null @@ -1,385 +0,0 @@ -import torch - -""" -Contains generic functions to pack and unpack uintx (2-7) tensors into uint8 tensors. -""" - - -def down_size_uint2(size): - assert size[-1] % 4 == 0, f"{size} last dim not divisible by four" - return (*size[:-1], size[-1] // 4) - - -def up_size_uint2(size): - return (*size[:-1], size[-1] * 4) - - -def unpack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: - # since we are using uint8 we will decode 4 entries per byte - shape = uint8_data.shape - uint8_data = uint8_data.to(torch.uint8) - first_elements = (uint8_data >> 6) & 0b11 - second_elements = (uint8_data >> 4) & 0b11 - third_elements = (uint8_data >> 2) & 0b11 - fourth_elements = uint8_data & 0b11 - return torch.stack( - (first_elements, second_elements, third_elements, fourth_elements), dim=-1 - ).view(up_size_uint2(shape)) - - -def pack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: - """pack lowest 2 bits of 2 uint8 -> 1 uint8""" - shape = uint8_data.shape - assert shape[-1] % 4 == 0 - uint8_data = uint8_data.contiguous().view(-1) - packed_data = ( - (uint8_data[::4] & 0b11) << 6 - | (uint8_data[1::4] & 0b11) << 4 - | (uint8_data[2::4] & 0b11) << 2 - | (uint8_data[3::4] & 0b11) - ).view(down_size_uint2(shape)) - return packed_data - - -def down_size_uint3(size): - assert size[-1] % 8 == 0, f"{size} last dim not divisible by eight" - return (*size[:-1], size[-1] // 8 * 3) - - -def up_size_uint3(size): - assert size[-1] % 3 == 0, f"{size} last dim not divisible by three" - return (*size[:-1], size[-1] // 3 * 8) - - -def unpack_uint3(uint8_data: torch.Tensor) -> torch.Tensor: - """ - 3 -> 8 - 01234567|01234567|01234567 - AAABBBCC|CDDDEEEF|FFGGGHHH - """ - shape = uint8_data.shape - uint8_data = uint8_data.to(torch.uint8) - - return torch.stack( - ( - (uint8_data[::3] >> 5) & 0b111, - (uint8_data[::3] >> 2) & 0b111, - (uint8_data[::3] & 0b11) << 1 | (uint8_data[1::3] >> 7) & 0b1, - (uint8_data[1::3] >> 4) & 0b111, - 
(uint8_data[1::3] >> 1) & 0b111, - (uint8_data[1::3] & 0b1) << 2 | (uint8_data[2::3] >> 6) & 0b11, - (uint8_data[2::3] >> 3) & 0b111, - uint8_data[2::3] & 0b111, - ), - dim=-1, - ).view(up_size_uint3(shape)) - - -def pack_uint3(uint8_data: torch.Tensor) -> torch.Tensor: - """ - 8 -> 3 - 01234567|01234567|01234567 - AAABBBCC|CDDDEEEF|FFGGGHHH - """ - - shape = uint8_data.shape - assert shape[-1] % 8 == 0 - uint8_data = uint8_data.contiguous().view(-1) - - packed_data = torch.stack( - ( - ( - (uint8_data[::8] & 0b111) << 5 - | (uint8_data[1::8] & 0b111) << 2 - | (uint8_data[2::8] & 0b111) >> 1 - ), - ( - (uint8_data[2::8] & 0b1) << 7 - | (uint8_data[3::8] & 0b111) << 4 - | (uint8_data[4::8] & 0b111) << 1 - | ((uint8_data[5::8] >> 2) & 1) - ), - ( - (uint8_data[5::8] & 0b11) << 6 - | (uint8_data[6::8] & 0b111) << 3 - | (uint8_data[7::8] & 0b111) - ), - ), - dim=-1, - ).view(down_size_uint3(shape)) - - return packed_data - - -def down_size_uint4(size): - assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" - return (*size[:-1], size[-1] // 2) - - -def up_size_uint4(size): - return (*size[:-1], size[-1] * 2) - - -def unpack_uint4(uint8_data: torch.Tensor) -> torch.Tensor: - shape = uint8_data.shape - uint8_data = uint8_data.to(torch.uint8) - first_elements = (uint8_data >> 4) & 0b1111 - second_elements = uint8_data & 0b1111 - return torch.stack((first_elements, second_elements), dim=-1).view( - up_size_uint4(shape) - ) - - -def pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor: - shape = uint8_data.shape - assert shape[-1] % 2 == 0 - uint8_data = uint8_data.contiguous().view(-1) - packed_data = (uint8_data[::2] << 4 | (uint8_data[1::2] & 0b1111)).view( - down_size_uint4(shape) - ) - return packed_data - - -def down_size_uint5(size): - assert size[-1] % 8 == 0, f"{size} last dim not divisible by 8" - return (*size[:-1], size[-1] // 8 * 5) - - -def up_size_uint5(size): - assert size[-1] % 5 == 0, f"{size} last dim not divisible by 5" - return (*size[:-1], size[-1] // 5 * 8) - - -def pack_uint5(uint8_data: torch.Tensor) -> torch.Tensor: - """Pack the 5 lowest bits of 8 input bytes into 5 bytes - - 8 -> 5 - 01234567|01234567|01234567|01234567|01234567 - AAAAABBB|BBCCCCCD|DDDDEEEE|EFFFFFGG|GGGHHHHH - - The packing pattern: - - First byte: (A0 A1 A2 A3 A4 B0 B1 B2) - - Second byte: (B3 B4 C0 C1 C2 C3 C4 D0) - - Third byte: (D1 D2 D3 D4 E0 E1 E2 E3) - - Fourth byte: (E4 F0 F1 F2 F3 F4 G0 G1) - - Fifth byte: (G2 G3 G4 H0 H1 H2 H3 H4) - """ - shape = uint8_data.shape - assert ( - shape[-1] % 8 == 0 - ), f"Input last dimension should be divisible by 8, but got {shape[-1]}" - - uint8_data = uint8_data.contiguous().view(-1, 8) - - packed_data = torch.stack( - ( - ((uint8_data[:, 0] & 0b00011111) << 3) - | ((uint8_data[:, 1] & 0b00011100) >> 2), - ((uint8_data[:, 1] & 0b00000011) << 6) - | ((uint8_data[:, 2] & 0b00011111) << 1) - | ((uint8_data[:, 3] & 0b10000) >> 4), - ((uint8_data[:, 3] & 0b00001111) << 4) - | ((uint8_data[:, 4] & 0b00011110) >> 1), - ((uint8_data[:, 4] & 0b00000001) << 7) - | ((uint8_data[:, 5] & 0b00011111) << 2) - | ((uint8_data[:, 6] & 0b0011000) >> 3), - ((uint8_data[:, 6] & 0b00000111) << 5) | (uint8_data[:, 7] & 0b00011111), - ), - dim=-1, - ).view(down_size_uint5(shape)) - - return packed_data - - -def unpack_uint5(packed_data: torch.Tensor) -> torch.Tensor: - """Unpack the 5 bytes into the 5 lowest bits of 8 bytes - 01234567|01234567|01234567|01234567|01234567 - AAAAABBB|BBCCCCCD|DDDDEEEE|EFFFFFGG|GGGHHHHH - """ - shape = packed_data.shape - assert ( - shape[-1] % 
5 == 0 - ), f"Input last dimension should be divisible by 5, but got {shape[-1]}" - - packed_data = packed_data.contiguous().view(-1, 5) - - unpacked_data = torch.stack( - ( - ((packed_data[:, 0] >> 3) & 0b00011111), - ((packed_data[:, 0] & 0b00000111) << 2) - | ((packed_data[:, 1] >> 6) & 0b00000011), - ((packed_data[:, 1] >> 1) & 0b00011111), - ((packed_data[:, 1] & 0b00000001) << 4) - | ((packed_data[:, 2] >> 4) & 0b00001111), - ((packed_data[:, 2] & 0b00001111) << 1) - | ((packed_data[:, 3] >> 7) & 0b00000001), - ((packed_data[:, 3] >> 2) & 0b00011111), - ((packed_data[:, 3] & 0b00000011) << 3) - | ((packed_data[:, 4] >> 5) & 0b00000111), - packed_data[:, 4] & 0b00011111, - ), - dim=-1, - ).view(up_size_uint5(shape)) - - return unpacked_data - - -def down_size_uint6(size): - assert size[-1] % 4 == 0, f"{size} last dim not divisible by four" - return (*size[:-1], size[-1] // 4 * 3) - - -def up_size_uint6(size): - assert size[-1] % 3 == 0, f"{size} last dim not divisible by three" - return (*size[:-1], size[-1] // 3 * 4) - - -def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: - """Pack the 6 lowest bits of 4 input bytes into 3 bytes - - 4 -> 3 - 01234567|01234567|01234567 - AAAAAABB|BBBBCCCC|CCDDDDDD - - The packing pattern: - - First byte: (A0 A1 A2 A3 A4 A5 B0 B1) - - Second byte: (B2 B3 B4 B5 C0 C1 C2 C3) - - Third byte: (C4 C5 D0 D1 D2 D3 D4 D5) - """ - shape = uint8_data.shape - assert ( - shape[-1] % 4 == 0 - ), f"Input last dimension should be divisible by 4, but got {shape[-1]}" - - uint8_data = uint8_data.contiguous().view(-1, 4) - - packed_data = torch.stack( - ( - ((uint8_data[:, 0] & 0b00111111) << 2) - | ((uint8_data[:, 1] >> 4) & 0b00000011), - ((uint8_data[:, 1] & 0b00001111) << 4) - | ((uint8_data[:, 2] >> 2) & 0b00001111), - ((uint8_data[:, 2] & 0b00000011) << 6) | (uint8_data[:, 3] & 0b00111111), - ), - dim=-1, - ).view(down_size_uint6(shape)) - - return packed_data - - -def unpack_uint6(packed_data: torch.Tensor) -> torch.Tensor: - """Unpack the 3 bytes into the 6 lowest bits of 4 outputs - 01234567|01234567|01234567 - AAAAAABB|BBBBCCCC|CCDDDDDD - """ - shape = packed_data.shape - assert ( - shape[-1] % 3 == 0 - ), f"Input last dimension should be divisible by 3, but got {shape[-1]}" - - packed_data = packed_data.contiguous().view(-1, 3) - - unpacked_data = torch.stack( - ( - (packed_data[:, 0] >> 2) & 0b00111111, - ((packed_data[:, 0] & 0b00000011) << 4) - | ((packed_data[:, 1] >> 4) & 0b00001111), - ((packed_data[:, 1] & 0b00001111) << 2) - | ((packed_data[:, 2] >> 6) & 0b00000011), - packed_data[:, 2] & 0b00111111, - ), - dim=-1, - ).view(up_size_uint6(shape)) - - return unpacked_data - - -def down_size_uint7(size): - assert size[-1] % 8 == 0, f"{size} last dim not divisible by 8" - return (*size[:-1], size[-1] // 8 * 7) - - -def up_size_uint7(size): - assert size[-1] % 7 == 0, f"{size} last dim not divisible by 7" - return (*size[:-1], size[-1] // 7 * 8) - - -def pack_uint7(uint8_data: torch.Tensor) -> torch.Tensor: - """Pack the 7 lowest bits of 8 input bytes into 7 bytes - - 8 -> 7 - 01234567|01234567|01234567|01234567|01234567|01234567|01234567 - AAAAAAAB|BBBBBBCC|CCCCCDDD|DDDDEEEE|EEEFFFFF|FFGGGGGG|GHHHHHHH - - The packing pattern: - - First byte: (A0 A1 A2 A3 A4 A5 A6 B0) - - Second byte: (B1 B2 B3 B4 B5 B6 C0 C1) - - Third byte: (C2 C3 C4 C5 C6 D0 D1 D2) - - Fourth byte: (D3 D4 D5 D6 E0 E1 E2 E3) - - Fifth byte: (E4 E5 E6 F0 F1 F2 F3 F4) - - Sixth byte: (F5 F6 G0 G1 G2 G3 G4 G5) - - Seventh byte:(G6 H0 H1 H2 H3 H4 H5 H6) - """ - shape = 
uint8_data.shape - assert ( - shape[-1] % 8 == 0 - ), f"Input last dimension should be divisible by 8, but got {shape[-1]}" - - uint8_data = uint8_data.contiguous().view(-1, 8) - - packed_data = torch.stack( - ( - ((uint8_data[:, 0] & 0b01111111) << 1) - | ((uint8_data[:, 1] >> 6) & 0b00000001), - ((uint8_data[:, 1] & 0b00111111) << 2) - | ((uint8_data[:, 2] >> 5) & 0b00000011), - ((uint8_data[:, 2] & 0b00011111) << 3) - | ((uint8_data[:, 3] >> 4) & 0b00000111), - ((uint8_data[:, 3] & 0b00001111) << 4) - | ((uint8_data[:, 4] >> 3) & 0b00001111), - ((uint8_data[:, 4] & 0b00000111) << 5) - | ((uint8_data[:, 5] >> 2) & 0b00011111), - ((uint8_data[:, 5] & 0b00000011) << 6) - | ((uint8_data[:, 6] >> 1) & 0b00111111), - ((uint8_data[:, 6] & 0b00000001) << 7) - | ((uint8_data[:, 7] >> 0) & 0b01111111), - ), - dim=-1, - ).view(down_size_uint7(shape)) - - return packed_data - - -def unpack_uint7(packed_data: torch.Tensor) -> torch.Tensor: - """Unpack the 7 bytes into the 7 lowest bits of 8 bytes - 01234567|01234567|01234567|01234567|01234567|01234567|01234567 - AAAAAAAB|BBBBBBCC|CCCCCDDD|DDDDEEEE|EEEFFFFF|FFGGGGGG|GHHHHHHH - """ - shape = packed_data.shape - assert ( - shape[-1] % 7 == 0 - ), f"Input last dimension should be divisible by 7, but got {shape[-1]}" - - packed_data = packed_data.contiguous().view(-1, 7) - - unpacked_data = torch.stack( - ( - (packed_data[:, 0] >> 1) & 0b01111111, - ((packed_data[:, 0] & 0b00000001) << 6) - | ((packed_data[:, 1] >> 2) & 0b01111111), - ((packed_data[:, 1] & 0b00000011) << 5) - | ((packed_data[:, 2] >> 3) & 0b01111111), - ((packed_data[:, 2] & 0b00000111) << 4) - | ((packed_data[:, 3] >> 4) & 0b01111111), - ((packed_data[:, 3] & 0b00001111) << 3) - | ((packed_data[:, 4] >> 5) & 0b01111111), - ((packed_data[:, 4] & 0b00011111) << 2) - | ((packed_data[:, 5] >> 6) & 0b01111111), - ((packed_data[:, 5] & 0b00111111) << 1) - | ((packed_data[:, 6] >> 7) & 0b01111111), - packed_data[:, 6] & 0b01111111, - ), - dim=-1, - ).view(up_size_uint7(shape)) - - return unpacked_data
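
For context on what the deleted uintgen helpers provided: each pack_uintN/unpack_uintN pair is a lossless round trip over the low N bits, which is the property the removed test/prototype/test_bitpacking_gen.py asserted. A minimal sketch of that property, assuming the pre-removal import path torchao.prototype.dtypes.uintgen (illustrative only):

    import torch
    from torchao.prototype.dtypes.uintgen import pack_uint5, unpack_uint5

    # 5-bit values; pack_uint5 requires the last dim to be divisible by 8
    x = torch.arange(0, 32, dtype=torch.uint8)
    packed = pack_uint5(x)                         # every 8 input bytes -> 5 packed bytes
    assert packed.shape[-1] == x.shape[-1] // 8 * 5
    assert torch.equal(unpack_uint5(packed), x)    # round trip is exact for values < 2**5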