From be78a081eb2000e32d483fa994cbd3c6ab6620f3 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 24 Jul 2024 02:56:39 -0400 Subject: [PATCH 1/5] initial flow for autoround Signed-off-by: yiliu30 --- .../prototype/autoround/auto_round_flow.py | 130 ++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 torchao/prototype/autoround/auto_round_flow.py diff --git a/torchao/prototype/autoround/auto_round_flow.py b/torchao/prototype/autoround/auto_round_flow.py new file mode 100644 index 0000000000..268882c584 --- /dev/null +++ b/torchao/prototype/autoround/auto_round_flow.py @@ -0,0 +1,130 @@ +# ==------------------------------------------------------------------------------------------== +# TorchAO +# ==------------------------------------------------------------------------------------------== +from typing import Optional, Callable, Any, List +import torch +import torchao.quantization as ao_quant + + +def create_qmodel_from_qdq_model(qdq_model): + # TODO: simplify this process by creating a new class at unwrapper stage + def _is_quantized_linear(model, fqn): + return hasattr(model, "scale") + + def create_qlinear(linear): + def _get_qdq_data(linear): + # TODO: below is a fake implementation + int_data = linear.weight() + scales_and_zeros = [(linear.scale, linear.zero_point)] + return int_data, scales_and_zeros + + int_data, scales_and_zeros = _get_qdq_data(linear) + # TODO: below is a fake implementation, need more quantization info to dispatch the right quantization class + woq_linear = ao_quant.Int4WeightOnlyQuantizedLinearWeight( + int_data, + scales_and_zeros, + transposed=False, + shape=linear.weight.shape, + groupsize=128, + inner_k_tiles=32, + dtype=linear.weight.dtype, + ) + return woq_linear + + qmodel = ao_quant.quant_api._replace_with_custom_fn_if_matches_filter( + qdq_model, create_qlinear, _is_quantized_linear + ) + return qmodel + + +class ModuleInputCapture(torch.nn.Module): + """Capture the input of the given module.""" + + def __init__(self): + super().__init__() + self.inputs: List[Any] = [] + + def forward(self, *args, **kwarsg): + self.inputs.append((args, kwarsg)) + + +class ObservedBlock(torch.nn.Module): + def __init__(self, float_block: torch.nn.Module, block_observer: ModuleInputCapture): + super().__init__() + # e.g., replace `transformers.models.llama.modeling_llama.LlamaDecoderLayer` + self.float_block = float_block + self.block_observer = block_observer + + def forward(self, *args, **kwarsg): + self.block_observer(*args, **kwarsg) + # Here we not really run the forward of float_block, but just capture the input of the block. + # We run the forward of the float_block in the `apply_auto_round` function,as we may only + # sample partial of the inputs. + + @classmethod + def from_float_block(cls, float_block): + # TODO: should we pass the float_block to `ModuleInputCapture`? 
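+        # Note: ModuleInputCapture only stores the captured (args, kwargs) pairs, so it does not
+        # need a reference to the block itself; the block is re-run later during optimization.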
+ block_observer = ModuleInputCapture() + return cls(float_block, block_observer) + + def get_module_inputs(self): + # TODO: concat all inputs + inputs = self.block_observer.inputs + return inputs + + +def insert_observers_for_block_( + model: torch.nn.Module, + block_observer: ModuleInputCapture, + filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, +) -> ObservedBlock: + replacement_fn = lambda m: ObservedBlock.from_float(m, block_observer) + return ao_quant.quant_api._replace_with_custom_fn_if_matches_filter(model, replacement_fn, filter_fn) + + +def apply_auto_round(observed_block: ObservedBlock): + # Call the autoround to execute the optimization process + import auto_round + + # Start the training process to update the v and alpha and betta. + rounder = auto_round.AutoRound( + model=observed_block, + tokenizer=None, + bits=4, + iters=2, + use_quant_input=False, # disable it for now + n_samples=2, # double-check it + amp=False, + ) + inputs = observed_block.get_module_inputs() + # TODO: rename the `quant_block` to `quant_block_` + rounder.quant_block(observed_block, input_ids=inputs["input_ids"], input_others=inputs["input_others"]) + return create_qmodel_from_qdq_model(observed_block) + + +# ==------------------------------------------------------------------------------------------== +# The Modeling User API +# ==------------------------------------------------------------------------------------------== + +# Step 0. Load the float model +import transformers +pretrained_model_name_or_path = "facebook/opt-125m" +model = transformers.AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path) + +# Step 1. replace the block with an observed block +# Similar with the `insert_observers_`, but for block + +is_block = lambda model, fqn: isinstance(model, transformers.models.opt.modeling_opt.OPTDecoderLayer) +block_observer = ModuleInputCapture() +insert_observers_for_block_(model, block_observer, is_block) + +# Step 2. calibrating / training +# For capturing the input of block +batch_size, seq_len, hidden_size = 2, 10, 768 +example_inputs = torch.rannd((batch_size, seq_len, hidden_size)) +for _ in range(10): + model(*example_inputs) + +# Step 3. 
quantize the block +is_observed_block = lambda model, fqn: isinstance(model, ObservedBlock) +ao_quant.quantize_(model, apply_auto_round, is_observed_block) From 49f8075ab94335c9b9f23df0ac5cfea54155ff07 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 25 Jul 2024 04:43:16 -0400 Subject: [PATCH 2/5] update flow Signed-off-by: yiliu30 --- .../prototype/autoround/auto_round_flow.py | 290 +++++++++++++++--- 1 file changed, 244 insertions(+), 46 deletions(-) diff --git a/torchao/prototype/autoround/auto_round_flow.py b/torchao/prototype/autoround/auto_round_flow.py index 268882c584..7932796743 100644 --- a/torchao/prototype/autoround/auto_round_flow.py +++ b/torchao/prototype/autoround/auto_round_flow.py @@ -1,9 +1,63 @@ +seed = 0 +import random + +random.seed(seed) +import torch + +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) +import numpy as np + +np.random.seed(seed) +from typing import Optional, Callable, Any, List, Tuple, Dict + + +def assert_same( + a: Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]], + b: Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]], +): + assert len(a) == len(b), f"len: {len(a)} != {len(b)}" + for i, _ in enumerate(a): + assert type(a[i]) == type(b[i]), f"type: {type(a[i])} != {type(b[i])}" + if isinstance(a[i], torch.Tensor): + torch.testing.assert_allclose(a[i], b[i]) + elif isinstance(a[i], tuple): + assert_same(a[i], b[i]) + elif isinstance(a[i], dict): + for k in a[i].keys(): + assert k in b[i], f"key: {k} not in {b[i]}" + assert_same(a[i][k], b[i].get(k)) + elif a[i] is None: + assert b[i] is None + else: + raise ValueError(f"Unsupported type: {type(a[i])}") + print("Same!") + + +def inspect_module_inputs(inputs, indent=""): + if isinstance(inputs, torch.Tensor): + print(f"{indent}Tensor: {inputs.shape}") + elif isinstance(inputs, tuple) or isinstance(inputs, list): + for i in inputs: + inspect_module_inputs(i, indent + " ") + elif isinstance(inputs, dict): + for k, v in inputs.items(): + print(f"{indent}{k}:") + inspect_module_inputs(v, indent + " ") + elif inputs is None: + print(f"{indent}None") + else: + print(f"{indent}{type(inputs)}") + + # ==------------------------------------------------------------------------------------------== # TorchAO # ==------------------------------------------------------------------------------------------== -from typing import Optional, Callable, Any, List + import torch import torchao.quantization as ao_quant +from functools import partial +import transformers def create_qmodel_from_qdq_model(qdq_model): @@ -13,14 +67,46 @@ def _is_quantized_linear(model, fqn): def create_qlinear(linear): def _get_qdq_data(linear): - # TODO: below is a fake implementation - int_data = linear.weight() - scales_and_zeros = [(linear.scale, linear.zero_point)] - return int_data, scales_and_zeros + # TODO: refine the impl, refer: https://github.com/yiliu30/auto-round/pull/2 + qdq_weight = linear.weight + scales = linear.scale.reshape(-1, 1) + zeros = linear.zp.reshape(-1, 1) + orig_shape = qdq_weight.shape + # breakpoint() + group_size = 128 + gs_shape = (-1, group_size) + qweight_with_origin_shape = (qdq_weight.reshape(gs_shape) / scales + zeros).round().reshape(orig_shape) + qweight_with_origin_shape = qweight_with_origin_shape.to(torch.int32) + qweight = qweight_with_origin_shape + + scales = scales.reshape(qweight.shape[0], -1) + zp = zeros.reshape(qweight.shape[0], -1) + + n, k = qweight.shape + assert ( + scales.shape == torch.Size([qweight.shape[0], qweight.shape[1] // group_size]) + ), f"expect 
scales shape {torch.Size([qweight.shape[0], qweight.shape[1] // group_size])}, but got {scales.shape}" + if zp is not None: + zp = zp.to(torch.bfloat16) + zp = zp.reshape(qweight.shape[0], -1) + assert ( + zp.shape == torch.Size([qweight.shape[0], qweight.shape[1] // group_size]) + ), f"expect zp shape {torch.Size([qweight.shape[0], qweight.shape[1] // group_size])}, but got {zp.shape}" + zeros = (8 - zp) * scales + + # Hard code inner_k_tiles = 2 + inner_k_tiles = 2 + # Pack to tinygemm reqiured format + packed_q = torch.ops.aten._convert_weight_to_int4pack(qweight, inner_k_tiles) + scales_and_zeros = ao_quant.utils.pack_tinygemm_scales_and_zeros( + scales.to(torch.bfloat16), zeros.to(torch.bfloat16) + ) + q_groups = k // group_size + return packed_q, scales_and_zeros int_data, scales_and_zeros = _get_qdq_data(linear) - # TODO: below is a fake implementation, need more quantization info to dispatch the right quantization class - woq_linear = ao_quant.Int4WeightOnlyQuantizedLinearWeight( + # TODO: Double check below args + woq_weight = ao_quant.Int4WeightOnlyQuantizedLinearWeight( int_data, scales_and_zeros, transposed=False, @@ -29,7 +115,10 @@ def _get_qdq_data(linear): inner_k_tiles=32, dtype=linear.weight.dtype, ) - return woq_linear + linear.weight = torch.nn.Parameter(woq_weight, requires_grad=False) + del linear.scale + del linear.zp + return linear qmodel = ao_quant.quant_api._replace_with_custom_fn_if_matches_filter( qdq_model, create_qlinear, _is_quantized_linear @@ -47,84 +136,193 @@ def __init__(self): def forward(self, *args, **kwarsg): self.inputs.append((args, kwarsg)) + def __repr__(self): + return f"ModuleInputCapture(inputs: {len(self.inputs)})" + class ObservedBlock(torch.nn.Module): - def __init__(self, float_block: torch.nn.Module, block_observer: ModuleInputCapture): + def __init__(self, float_block: torch.nn.Module, block_observer: ModuleInputCapture, input_hook_handle=None): super().__init__() # e.g., replace `transformers.models.llama.modeling_llama.LlamaDecoderLayer` self.float_block = float_block self.block_observer = block_observer + self.input_hook_handle = input_hook_handle + + def remove_input_hook_handle(self): + self.input_hook_handle.remove() def forward(self, *args, **kwarsg): - self.block_observer(*args, **kwarsg) - # Here we not really run the forward of float_block, but just capture the input of the block. - # We run the forward of the float_block in the `apply_auto_round` function,as we may only - # sample partial of the inputs. + return self.float_block(*args, **kwarsg) @classmethod - def from_float_block(cls, float_block): - # TODO: should we pass the float_block to `ModuleInputCapture`? - block_observer = ModuleInputCapture() - return cls(float_block, block_observer) + def from_float(cls, float_block: torch.nn.Module, block_observer: ModuleInputCapture = None): + # TODO: only insert hook to the float_block to capture the input and save it to the block_observer + # TODO: look like no need new module for it? 
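+        # A forward pre-hook registered with `with_kwargs=True` receives (module, args, kwargs);
+        # returning them unchanged keeps the block's forward behavior intact while each call is recorded.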
+ def capture_inputs_hook( + block_observer: ModuleInputCapture, + module: torch.nn.Module, + args: Tuple[torch.Tensor], + kwargs: Dict[str, Any], + ) -> Tuple[Any, Any]: + block_observer.inputs.append((args, kwargs)) + return args, kwargs + + if block_observer is None: + block_observer = ModuleInputCapture() + hook_handle = float_block.register_forward_pre_hook( + partial(capture_inputs_hook, block_observer), with_kwargs=True + ) + return cls(float_block, block_observer, hook_handle) def get_module_inputs(self): - # TODO: concat all inputs inputs = self.block_observer.inputs return inputs def insert_observers_for_block_( model: torch.nn.Module, - block_observer: ModuleInputCapture, filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, ) -> ObservedBlock: - replacement_fn = lambda m: ObservedBlock.from_float(m, block_observer) + replacement_fn = lambda m: ObservedBlock.from_float(m) return ao_quant.quant_api._replace_with_custom_fn_if_matches_filter(model, replacement_fn, filter_fn) def apply_auto_round(observed_block: ObservedBlock): + block_inputs = observed_block.block_observer.inputs + # first_inputs = block_inputs[0][1] + # hidden_states = first_inputs["hidden_states"] + # position_ids = first_inputs["position_ids"] + # attention_mask = first_inputs["attention_mask"] + + # # WA for now + # _input_ids = hidden_states + # _input_others = {"positional_inputs": position_ids, "attention_mask": attention_mask} + + _input_ids = block_inputs[0][0][0].detach() + position_ids = [] + attention_mask = block_inputs[0][1]["attention_mask"].detach() + _input_others = {"positional_inputs": position_ids, "attention_mask": attention_mask} + + seq_len = _input_ids.shape[1] # Call the autoround to execute the optimization process import auto_round + block = observed_block.float_block + block.dtype = next(block.parameters()).dtype + # Start the training process to update the v and alpha and betta. 
rounder = auto_round.AutoRound( - model=observed_block, + model=block, tokenizer=None, bits=4, iters=2, use_quant_input=False, # disable it for now - n_samples=2, # double-check it + n_samples=1, # double-check it amp=False, + seqlen=seq_len, + batch_size=1, ) - inputs = observed_block.get_module_inputs() # TODO: rename the `quant_block` to `quant_block_` - rounder.quant_block(observed_block, input_ids=inputs["input_ids"], input_others=inputs["input_others"]) + rounder.quant_block(block, input_ids=_input_ids, input_others=_input_others) return create_qmodel_from_qdq_model(observed_block) +# ==------------------------------------------------------------------------------------------== +# Tests +# ==------------------------------------------------------------------------------------------== + + +import pytest + + +class TestFlow: + @torch.no_grad() + def test_obseverblock(self): + model = transformers.AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM") + block = model.transformer.h[0] + observed_block = ObservedBlock.from_float(block) + bs, seq_len, hidden_size = 2, 3, 32 + hidden_states = torch.randn((bs, seq_len, hidden_size)) + position_ids = torch.randint(0, seq_len, (bs, seq_len)) + # Record the input and output of the block + origin_output = [] + out1 = observed_block(hidden_states, position_ids=position_ids) + origin_output.append(out1) + print(observed_block.block_observer.inputs) + attention_mask = torch.randn(bs, 4, seq_len, seq_len) + out2 = observed_block(hidden_states, None, attention_mask, position_ids=position_ids) + origin_output.append(out2) + print(observed_block.block_observer.inputs) + observed_block.remove_input_hook_handle() + # Replay + new_output = [] + for args, kwargs in observed_block.block_observer.inputs: + out = observed_block(*args, **kwargs) + new_output.append(out) + assert_same(origin_output, new_output) + + def test_with_gptj(self): + with torch.no_grad(): + # Step 0. Load the float model + import transformers + + # pretrained_model_name_or_path = "hf-internal-testing/tiny-random-GPTJForCausalLM" + pretrained_model_name_or_path = "facebook/opt-125m" + model = transformers.AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path) + tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path) + + # Step 1. replace the block with an observed block + # Similar with the `insert_observers_`, but for block + is_block = lambda model, fqn: isinstance(model, transformers.models.opt.modeling_opt.OPTDecoderLayer) + # is_block = lambda model, fqn: isinstance(model, transformers.models.gptj.modeling_gptj.GPTJBlock) + block_observer = ModuleInputCapture() + insert_observers_for_block_(model, is_block) + + print(f"model with observer (before calibration): \n{model}") + + # Step 2. calibrating / training + # For capturing the input of block + # batch_size, seq_len, hidden_size = 2, 5, 32 + iters = 4 + prompt = "The meaning of life is" + # "input_ids", "attention_mask" + example_inputs = tokenizer(prompt, return_tensors="pt") + for _ in range(iters): + model(**example_inputs) + + print(f"model with observer (after calibration): \n{model}") + + # Step 3. 
quantize the block + is_observed_block = lambda model, fqn: isinstance(model, ObservedBlock) + ao_quant.quantize_(model, apply_auto_round, is_observed_block) + + # ==------------------------------------------------------------------------------------------== # The Modeling User API # ==------------------------------------------------------------------------------------------== -# Step 0. Load the float model -import transformers -pretrained_model_name_or_path = "facebook/opt-125m" -model = transformers.AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path) - -# Step 1. replace the block with an observed block -# Similar with the `insert_observers_`, but for block - -is_block = lambda model, fqn: isinstance(model, transformers.models.opt.modeling_opt.OPTDecoderLayer) -block_observer = ModuleInputCapture() -insert_observers_for_block_(model, block_observer, is_block) - -# Step 2. calibrating / training -# For capturing the input of block -batch_size, seq_len, hidden_size = 2, 10, 768 -example_inputs = torch.rannd((batch_size, seq_len, hidden_size)) -for _ in range(10): - model(*example_inputs) - -# Step 3. quantize the block -is_observed_block = lambda model, fqn: isinstance(model, ObservedBlock) -ao_quant.quantize_(model, apply_auto_round, is_observed_block) + +def test_user_api(): + # Step 0. Load the float model + import transformers + + pretrained_model_name_or_path = "facebook/opt-125m" + model = transformers.AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path) + + # Step 1. replace the block with an observed block + # Similar with the `insert_observers_`, but for block + + is_block = lambda model, fqn: isinstance(model, transformers.models.opt.modeling_opt.OPTDecoderLayer) + insert_observers_for_block_(model, is_block) + + # Step 2. calibrating / training + # For capturing the input of block + batch_size, seq_len, hidden_size = 2, 10, 768 + example_inputs = torch.rannd((batch_size, seq_len, hidden_size)) + + for _ in range(10): + model(*example_inputs) + + # Step 3. 
quantize the block + is_observed_block = lambda model, fqn: isinstance(model, ObservedBlock) + ao_quant.quantize_(model, apply_auto_round, is_observed_block) From 62834a22498a91218d72fe83ed7c7e59d5ac99b1 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 26 Jul 2024 04:54:10 -0400 Subject: [PATCH 3/5] use int4 kernel Signed-off-by: yiliu30 --- .../prototype/autoround/auto_round_flow.py | 97 +++++++++++++------ torchao/quantization/subclass.py | 8 +- 2 files changed, 72 insertions(+), 33 deletions(-) diff --git a/torchao/prototype/autoround/auto_round_flow.py b/torchao/prototype/autoround/auto_round_flow.py index 7932796743..56d52f6a93 100644 --- a/torchao/prototype/autoround/auto_round_flow.py +++ b/torchao/prototype/autoround/auto_round_flow.py @@ -1,3 +1,8 @@ +# ==------------------------------------------------------------------------------------------== +# Utils +# ==------------------------------------------------------------------------------------------== +import os + seed = 0 import random @@ -65,46 +70,67 @@ def create_qmodel_from_qdq_model(qdq_model): def _is_quantized_linear(model, fqn): return hasattr(model, "scale") + @torch.no_grad() def create_qlinear(linear): def _get_qdq_data(linear): # TODO: refine the impl, refer: https://github.com/yiliu30/auto-round/pull/2 - qdq_weight = linear.weight - scales = linear.scale.reshape(-1, 1) - zeros = linear.zp.reshape(-1, 1) + # qdq_weight shape: (m, k) + qdq_weight = linear.weight.clone() + device = qdq_weight.device + # scales, zeros shape: (m, n_groups) + scales = linear.scale.to(device) + zeros = linear.zp.to(device) + + # Requantize the qdqweight to get the int_data orig_shape = qdq_weight.shape - # breakpoint() - group_size = 128 - gs_shape = (-1, group_size) - qweight_with_origin_shape = (qdq_weight.reshape(gs_shape) / scales + zeros).round().reshape(orig_shape) - qweight_with_origin_shape = qweight_with_origin_shape.to(torch.int32) - qweight = qweight_with_origin_shape + oc, ic = orig_shape + groupsize = linear.group_size + assert ic % groupsize == 0, f"expect k % groupsize == 0, but got {ic % groupsize}" + n_groups = ic // groupsize - scales = scales.reshape(qweight.shape[0], -1) - zp = zeros.reshape(qweight.shape[0], -1) + # Check the shapes of scales and zeros with int_data + scales_zeros_expected_shape = torch.Size([oc, n_groups]) + assert ( + scales.shape == scales_zeros_expected_shape + ), f"expect scales shape {scales_zeros_expected_shape}, but got {scales.shape}" - n, k = qweight.shape assert ( - scales.shape == torch.Size([qweight.shape[0], qweight.shape[1] // group_size]) - ), f"expect scales shape {torch.Size([qweight.shape[0], qweight.shape[1] // group_size])}, but got {scales.shape}" - if zp is not None: - zp = zp.to(torch.bfloat16) - zp = zp.reshape(qweight.shape[0], -1) - assert ( - zp.shape == torch.Size([qweight.shape[0], qweight.shape[1] // group_size]) - ), f"expect zp shape {torch.Size([qweight.shape[0], qweight.shape[1] // group_size])}, but got {zp.shape}" - zeros = (8 - zp) * scales + zeros.shape == scales_zeros_expected_shape + ), f"expect zeros shape {scales_zeros_expected_shape}, but got {zeros.shape}" + + flatten_scales = scales.reshape(-1, 1) + flatten_zeros = zeros.reshape(-1, 1) + gs_shape = (-1, groupsize) + int_data = ( + qdq_weight.reshape(gs_shape) + .div(flatten_scales) + .add(flatten_zeros) + .round() + .reshape(orig_shape) + .to(torch.int32) + ) + # Shift the zeros to align with tinnygemm + # TODO: more notes or discard this step + zeros = (8 - zeros) * scales + + # Pack to tinygemm reqiured 
format # Hard code inner_k_tiles = 2 inner_k_tiles = 2 - # Pack to tinygemm reqiured format - packed_q = torch.ops.aten._convert_weight_to_int4pack(qweight, inner_k_tiles) + + packed_int_data = torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles) scales_and_zeros = ao_quant.utils.pack_tinygemm_scales_and_zeros( scales.to(torch.bfloat16), zeros.to(torch.bfloat16) ) - q_groups = k // group_size - return packed_q, scales_and_zeros + return packed_int_data, scales_and_zeros + # For Debug int_data, scales_and_zeros = _get_qdq_data(linear) + float_out = None + random_input = torch.randn((4, linear.weight.shape[1]), dtype=torch.bfloat16).to(int_data.device) + if os.environ.get("DEBUG", "0") == "1": + float_out = linear(random_input.to(linear.weight.dtype)) + # TODO: Double check below args woq_weight = ao_quant.Int4WeightOnlyQuantizedLinearWeight( int_data, @@ -113,11 +139,18 @@ def _get_qdq_data(linear): shape=linear.weight.shape, groupsize=128, inner_k_tiles=32, - dtype=linear.weight.dtype, + dtype=torch.bfloat16, ) linear.weight = torch.nn.Parameter(woq_weight, requires_grad=False) del linear.scale del linear.zp + if os.environ.get("DEBUG", "0") == "1": + new_lin_int_out = linear(random_input) + # _max = (float_out - new_lin_int_out).abs().max() + # print(f"Max diff between float and int: {_max}") + assert torch.allclose( + float_out.to(new_lin_int_out.dtype), new_lin_int_out, atol=2e-1 + ), f"The max diff between float and int is too large: {(float_out - new_lin_int_out).abs().max()}" return linear qmodel = ao_quant.quant_api._replace_with_custom_fn_if_matches_filter( @@ -214,6 +247,7 @@ def apply_auto_round(observed_block: ObservedBlock): rounder = auto_round.AutoRound( model=block, tokenizer=None, + sym=False, # Both True and False are OK bits=4, iters=2, use_quant_input=False, # disable it for now @@ -221,9 +255,10 @@ def apply_auto_round(observed_block: ObservedBlock): amp=False, seqlen=seq_len, batch_size=1, + low_gpu_mem_usage=False, ) # TODO: rename the `quant_block` to `quant_block_` - rounder.quant_block(block, input_ids=_input_ids, input_others=_input_others) + rounder.quant_block(block, input_ids=_input_ids, input_others=_input_others, device=torch.device("cuda")) return create_qmodel_from_qdq_model(observed_block) @@ -261,14 +296,16 @@ def test_obseverblock(self): new_output.append(out) assert_same(origin_output, new_output) - def test_with_gptj(self): + def test_with_opt(self): with torch.no_grad(): + device = torch.device("cuda") + # Step 0. Load the float model import transformers # pretrained_model_name_or_path = "hf-internal-testing/tiny-random-GPTJForCausalLM" pretrained_model_name_or_path = "facebook/opt-125m" - model = transformers.AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path) + model = transformers.AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).to(device) tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path) # Step 1. 
replace the block with an observed block @@ -288,7 +325,7 @@ def test_with_gptj(self): # "input_ids", "attention_mask" example_inputs = tokenizer(prompt, return_tensors="pt") for _ in range(iters): - model(**example_inputs) + model(**example_inputs.to(device)) print(f"model with observer (after calibration): \n{model}") diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index a2801a622f..18751302b3 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -126,6 +126,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): kwargs = {} if kwargs is None else kwargs if func is torch.nn.functional.linear: + breakpoint() mat1, w_qtensor, bias = ( args[0], args[1], @@ -441,11 +442,11 @@ def __init__( def _quantized_op(act_mat, w_qtensor, bias): orig_act_size = act_mat.size() orig_dtype = act_mat.dtype - # reshape and pad activation act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16) - pad_size = find_multiple(act_mat.shape[-1], 1024) - act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1])) + # Any padding for weight? Otherwise it may cause the mismatch of the shape of the input and the weight + # pad_size = find_multiple(act_mat.shape[-1], 1024) + # act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1])) # matmul y = aten._weight_int4pack_mm( @@ -467,6 +468,7 @@ def _quantized_op(act_mat, w_qtensor, bias): return y.to(orig_dtype) def dequantize(self): + breakpoint() eye_shape = self.shape[1] if not self.transposed else self.shape[0] w_dq = self._quantized_op( torch.eye(eye_shape, device=self.device, dtype=self.dtype), self, None From 6433e756479d1b1e09fb8494b4f3fb18f91319b2 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 26 Jul 2024 04:55:51 -0400 Subject: [PATCH 4/5] remove debug code Signed-off-by: yiliu30 --- torchao/quantization/subclass.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 18751302b3..05166ddf19 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -126,7 +126,6 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): kwargs = {} if kwargs is None else kwargs if func is torch.nn.functional.linear: - breakpoint() mat1, w_qtensor, bias = ( args[0], args[1], @@ -442,6 +441,7 @@ def __init__( def _quantized_op(act_mat, w_qtensor, bias): orig_act_size = act_mat.size() orig_dtype = act_mat.dtype + # reshape and pad activation act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16) # Any padding for weight? 
Otherwise it may cause the mismatch of the shape of the input and the weight @@ -468,7 +468,6 @@ def _quantized_op(act_mat, w_qtensor, bias): return y.to(orig_dtype) def dequantize(self): - breakpoint() eye_shape = self.shape[1] if not self.transposed else self.shape[0] w_dq = self._quantized_op( torch.eye(eye_shape, device=self.device, dtype=self.dtype), self, None From 65f46e5f5d7e3f48ea96f3a859919367d7941c71 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 29 Jul 2024 05:04:14 -0400 Subject: [PATCH 5/5] update the forward Signed-off-by: yiliu30 --- torchao/prototype/autoround/__init__.py | 0 .../prototype/autoround/auto_round_flow.py | 222 ++++-------------- torchao/prototype/autoround/utils.py | 108 +++++++++ 3 files changed, 158 insertions(+), 172 deletions(-) create mode 100644 torchao/prototype/autoround/__init__.py create mode 100644 torchao/prototype/autoround/utils.py diff --git a/torchao/prototype/autoround/__init__.py b/torchao/prototype/autoround/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/prototype/autoround/auto_round_flow.py b/torchao/prototype/autoround/auto_round_flow.py index 56d52f6a93..5a602cc60c 100644 --- a/torchao/prototype/autoround/auto_round_flow.py +++ b/torchao/prototype/autoround/auto_round_flow.py @@ -1,59 +1,7 @@ -# ==------------------------------------------------------------------------------------------== -# Utils -# ==------------------------------------------------------------------------------------------== -import os - -seed = 0 -import random - -random.seed(seed) -import torch - -torch.manual_seed(seed) -torch.cuda.manual_seed(seed) -import numpy as np - -np.random.seed(seed) +from torchao.prototype.autoround.utils import freeze_random, assert_same, get_dataloader from typing import Optional, Callable, Any, List, Tuple, Dict - -def assert_same( - a: Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]], - b: Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]], -): - assert len(a) == len(b), f"len: {len(a)} != {len(b)}" - for i, _ in enumerate(a): - assert type(a[i]) == type(b[i]), f"type: {type(a[i])} != {type(b[i])}" - if isinstance(a[i], torch.Tensor): - torch.testing.assert_allclose(a[i], b[i]) - elif isinstance(a[i], tuple): - assert_same(a[i], b[i]) - elif isinstance(a[i], dict): - for k in a[i].keys(): - assert k in b[i], f"key: {k} not in {b[i]}" - assert_same(a[i][k], b[i].get(k)) - elif a[i] is None: - assert b[i] is None - else: - raise ValueError(f"Unsupported type: {type(a[i])}") - print("Same!") - - -def inspect_module_inputs(inputs, indent=""): - if isinstance(inputs, torch.Tensor): - print(f"{indent}Tensor: {inputs.shape}") - elif isinstance(inputs, tuple) or isinstance(inputs, list): - for i in inputs: - inspect_module_inputs(i, indent + " ") - elif isinstance(inputs, dict): - for k, v in inputs.items(): - print(f"{indent}{k}:") - inspect_module_inputs(v, indent + " ") - elif inputs is None: - print(f"{indent}None") - else: - print(f"{indent}{type(inputs)}") - +freeze_random() # ==------------------------------------------------------------------------------------------== # TorchAO @@ -63,6 +11,7 @@ def inspect_module_inputs(inputs, indent=""): import torchao.quantization as ao_quant from functools import partial import transformers +import os def create_qmodel_from_qdq_model(qdq_model): @@ -72,12 +21,11 @@ def _is_quantized_linear(model, fqn): @torch.no_grad() def create_qlinear(linear): - def _get_qdq_data(linear): - # TODO: refine the impl, refer: 
https://github.com/yiliu30/auto-round/pull/2 - # qdq_weight shape: (m, k) + def _get_qinfo(linear): + # qdq_weight shape: (oc, ic) qdq_weight = linear.weight.clone() device = qdq_weight.device - # scales, zeros shape: (m, n_groups) + # scales, zeros shape: (oc, n_groups) scales = linear.scale.to(device) zeros = linear.zp.to(device) @@ -110,7 +58,7 @@ def _get_qdq_data(linear): .to(torch.int32) ) - # Shift the zeros to align with tinnygemm + # Shift the zeros to align with tinygemm # TODO: more notes or discard this step zeros = (8 - zeros) * scales @@ -124,14 +72,7 @@ def _get_qdq_data(linear): ) return packed_int_data, scales_and_zeros - # For Debug - int_data, scales_and_zeros = _get_qdq_data(linear) - float_out = None - random_input = torch.randn((4, linear.weight.shape[1]), dtype=torch.bfloat16).to(int_data.device) - if os.environ.get("DEBUG", "0") == "1": - float_out = linear(random_input.to(linear.weight.dtype)) - - # TODO: Double check below args + int_data, scales_and_zeros = _get_qinfo(linear) woq_weight = ao_quant.Int4WeightOnlyQuantizedLinearWeight( int_data, scales_and_zeros, @@ -144,13 +85,6 @@ def _get_qdq_data(linear): linear.weight = torch.nn.Parameter(woq_weight, requires_grad=False) del linear.scale del linear.zp - if os.environ.get("DEBUG", "0") == "1": - new_lin_int_out = linear(random_input) - # _max = (float_out - new_lin_int_out).abs().max() - # print(f"Max diff between float and int: {_max}") - assert torch.allclose( - float_out.to(new_lin_int_out.dtype), new_lin_int_out, atol=2e-1 - ), f"The max diff between float and int is too large: {(float_out - new_lin_int_out).abs().max()}" return linear qmodel = ao_quant.quant_api._replace_with_custom_fn_if_matches_filter( @@ -164,7 +98,9 @@ class ModuleInputCapture(torch.nn.Module): def __init__(self): super().__init__() - self.inputs: List[Any] = [] + # [(args, kwargs), ...] + self.inputs: List[Tuple[Tuple[Any], Dict[str, Any]]] = [] + self.outputs = [] def forward(self, *args, **kwarsg): self.inputs.append((args, kwarsg)) @@ -174,23 +110,30 @@ def __repr__(self): class ObservedBlock(torch.nn.Module): - def __init__(self, float_block: torch.nn.Module, block_observer: ModuleInputCapture, input_hook_handle=None): + def __init__( + self, + float_block: torch.nn.Module, + block_observer: ModuleInputCapture, + input_hook_handle=None, + output_hook_handle=None, + ): super().__init__() # e.g., replace `transformers.models.llama.modeling_llama.LlamaDecoderLayer` self.float_block = float_block self.block_observer = block_observer self.input_hook_handle = input_hook_handle + self.output_hook_handle = output_hook_handle - def remove_input_hook_handle(self): + def remove_hook_handles(self): self.input_hook_handle.remove() + self.output_hook_handle.remove() def forward(self, *args, **kwarsg): return self.float_block(*args, **kwarsg) @classmethod def from_float(cls, float_block: torch.nn.Module, block_observer: ModuleInputCapture = None): - # TODO: only insert hook to the float_block to capture the input and save it to the block_observer - # TODO: look like no need new module for it? + # TODO: remove `block_observer`? 
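+        # The pre-hook below records each (args, kwargs) pair and the forward hook records the matching
+        # output, so `apply_auto_round` can later replay the block offline against its captured data.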
def capture_inputs_hook( block_observer: ModuleInputCapture, module: torch.nn.Module, @@ -200,16 +143,28 @@ def capture_inputs_hook( block_observer.inputs.append((args, kwargs)) return args, kwargs + def capture_outputs_hook( + block_observer: ModuleInputCapture, + module: torch.nn.Module, + inputs, + outputs, + ): + block_observer.outputs.append(outputs) + return outputs + if block_observer is None: block_observer = ModuleInputCapture() - hook_handle = float_block.register_forward_pre_hook( + pre_forward_hook_handle = float_block.register_forward_pre_hook( partial(capture_inputs_hook, block_observer), with_kwargs=True ) - return cls(float_block, block_observer, hook_handle) + forward_hook_handle = float_block.register_forward_hook(partial(capture_outputs_hook, block_observer)) + return cls(float_block, block_observer, pre_forward_hook_handle, forward_hook_handle) - def get_module_inputs(self): + def get_module_inputs_outputs(self): + self.remove_hook_handles() inputs = self.block_observer.inputs - return inputs + outputs = self.block_observer.outputs + return inputs, outputs def insert_observers_for_block_( @@ -221,29 +176,14 @@ def insert_observers_for_block_( def apply_auto_round(observed_block: ObservedBlock): - block_inputs = observed_block.block_observer.inputs - # first_inputs = block_inputs[0][1] - # hidden_states = first_inputs["hidden_states"] - # position_ids = first_inputs["position_ids"] - # attention_mask = first_inputs["attention_mask"] - - # # WA for now - # _input_ids = hidden_states - # _input_others = {"positional_inputs": position_ids, "attention_mask": attention_mask} - - _input_ids = block_inputs[0][0][0].detach() - position_ids = [] - attention_mask = block_inputs[0][1]["attention_mask"].detach() - _input_others = {"positional_inputs": position_ids, "attention_mask": attention_mask} - - seq_len = _input_ids.shape[1] + block_inputs, block_outputs = observed_block.get_module_inputs_outputs() # Call the autoround to execute the optimization process import auto_round block = observed_block.float_block - block.dtype = next(block.parameters()).dtype - # Start the training process to update the v and alpha and betta. + # Start the training process to update the v, alpha and betta. 
+ # TODO: refactor the `quant_block_new` to a static method rounder = auto_round.AutoRound( model=block, tokenizer=None, @@ -251,14 +191,11 @@ def apply_auto_round(observed_block: ObservedBlock): bits=4, iters=2, use_quant_input=False, # disable it for now - n_samples=1, # double-check it amp=False, - seqlen=seq_len, - batch_size=1, low_gpu_mem_usage=False, + model_dtype=next(block.parameters()).dtype, ) - # TODO: rename the `quant_block` to `quant_block_` - rounder.quant_block(block, input_ids=_input_ids, input_others=_input_others, device=torch.device("cuda")) + rounder.quant_block_new(block, block_inputs, block_outputs) return create_qmodel_from_qdq_model(observed_block) @@ -267,40 +204,14 @@ def apply_auto_round(observed_block: ObservedBlock): # ==------------------------------------------------------------------------------------------== -import pytest - - class TestFlow: - @torch.no_grad() - def test_obseverblock(self): - model = transformers.AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM") - block = model.transformer.h[0] - observed_block = ObservedBlock.from_float(block) - bs, seq_len, hidden_size = 2, 3, 32 - hidden_states = torch.randn((bs, seq_len, hidden_size)) - position_ids = torch.randint(0, seq_len, (bs, seq_len)) - # Record the input and output of the block - origin_output = [] - out1 = observed_block(hidden_states, position_ids=position_ids) - origin_output.append(out1) - print(observed_block.block_observer.inputs) - attention_mask = torch.randn(bs, 4, seq_len, seq_len) - out2 = observed_block(hidden_states, None, attention_mask, position_ids=position_ids) - origin_output.append(out2) - print(observed_block.block_observer.inputs) - observed_block.remove_input_hook_handle() - # Replay - new_output = [] - for args, kwargs in observed_block.block_observer.inputs: - out = observed_block(*args, **kwargs) - new_output.append(out) - assert_same(origin_output, new_output) - def test_with_opt(self): + # ==------------------------------------------------------------------------------------------== + # The Modeling User API + # ==------------------------------------------------------------------------------------------== with torch.no_grad(): - device = torch.device("cuda") - # Step 0. Load the float model + device = torch.device("cuda") import transformers # pretrained_model_name_or_path = "hf-internal-testing/tiny-random-GPTJForCausalLM" @@ -312,54 +223,21 @@ def test_with_opt(self): # Similar with the `insert_observers_`, but for block is_block = lambda model, fqn: isinstance(model, transformers.models.opt.modeling_opt.OPTDecoderLayer) # is_block = lambda model, fqn: isinstance(model, transformers.models.gptj.modeling_gptj.GPTJBlock) - block_observer = ModuleInputCapture() insert_observers_for_block_(model, is_block) - print(f"model with observer (before calibration): \n{model}") + print(f"Model with observer (before calibration): \n{model}") # Step 2. calibrating / training # For capturing the input of block - # batch_size, seq_len, hidden_size = 2, 5, 32 + # TODO: replace it with a real calibration dataset iters = 4 prompt = "The meaning of life is" - # "input_ids", "attention_mask" example_inputs = tokenizer(prompt, return_tensors="pt") for _ in range(iters): model(**example_inputs.to(device)) - print(f"model with observer (after calibration): \n{model}") + print(f"Model with observer (after calibration): \n{model}") # Step 3. 
quantize the block is_observed_block = lambda model, fqn: isinstance(model, ObservedBlock) ao_quant.quantize_(model, apply_auto_round, is_observed_block) - - -# ==------------------------------------------------------------------------------------------== -# The Modeling User API -# ==------------------------------------------------------------------------------------------== - - -def test_user_api(): - # Step 0. Load the float model - import transformers - - pretrained_model_name_or_path = "facebook/opt-125m" - model = transformers.AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path) - - # Step 1. replace the block with an observed block - # Similar with the `insert_observers_`, but for block - - is_block = lambda model, fqn: isinstance(model, transformers.models.opt.modeling_opt.OPTDecoderLayer) - insert_observers_for_block_(model, is_block) - - # Step 2. calibrating / training - # For capturing the input of block - batch_size, seq_len, hidden_size = 2, 10, 768 - example_inputs = torch.rannd((batch_size, seq_len, hidden_size)) - - for _ in range(10): - model(*example_inputs) - - # Step 3. quantize the block - is_observed_block = lambda model, fqn: isinstance(model, ObservedBlock) - ao_quant.quantize_(model, apply_auto_round, is_observed_block) diff --git a/torchao/prototype/autoround/utils.py b/torchao/prototype/autoround/utils.py new file mode 100644 index 0000000000..6b400a35cf --- /dev/null +++ b/torchao/prototype/autoround/utils.py @@ -0,0 +1,108 @@ +# ==------------------------------------------------------------------------------------------== +# Utils +# ==------------------------------------------------------------------------------------------== +from typing import Optional, Callable, Any, List, Tuple, Dict + +import random +import os +import torch +import numpy as np + + +def freeze_random(): + seed = 0 + + random.seed(seed) + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + np.random.seed(seed) + + +def assert_same( + a: Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]], + b: Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]], +): + assert len(a) == len(b), f"len: {len(a)} != {len(b)}" + for i, _ in enumerate(a): + assert type(a[i]) == type(b[i]), f"type: {type(a[i])} != {type(b[i])}" + if isinstance(a[i], torch.Tensor): + torch.testing.assert_allclose(a[i], b[i]) + elif isinstance(a[i], tuple): + assert_same(a[i], b[i]) + elif isinstance(a[i], dict): + for k in a[i].keys(): + assert k in b[i], f"key: {k} not in {b[i]}" + assert_same(a[i][k], b[i].get(k)) + elif a[i] is None: + assert b[i] is None + else: + raise ValueError(f"Unsupported type: {type(a[i])}") + print("Same!") + + +def inspect_module_inputs(inputs, indent=""): + if isinstance(inputs, torch.Tensor): + print(f"{indent}Tensor: {inputs.shape}") + elif isinstance(inputs, tuple) or isinstance(inputs, list): + for i in inputs: + inspect_module_inputs(i, indent + " ") + elif isinstance(inputs, dict): + for k, v in inputs.items(): + print(f"{indent}{k}:") + inspect_module_inputs(v, indent + " ") + elif inputs is None: + print(f"{indent}None") + else: + print(f"{indent}{type(inputs)}") + + +def get_tokenizer_function(tokenizer, seqlen): + """Returns a default tokenizer function. + + Args: + tokenizer: The tokenizer to be used for tokenization. + seqlen: The maximum sequence length. + + Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length of + seqlen to the "text" field of examples. 
+ """ + + def default_tokenizer_function(examples): + example = tokenizer(examples["text"], truncation=True, max_length=seqlen) + return example + + return default_tokenizer_function + + +def get_dataloader(tokenizer, seqlen=1024, dataset_name="NeelNanda/pile-10k", split="train", seed=42, batch_size=4): + from datasets import load_dataset + from torch.utils.data import DataLoader + + tokenizer_function = get_tokenizer_function(tokenizer, seqlen) + + @torch.no_grad() + def collate_batch(batch): + input_ids_new = [] + for text in batch: + input_ids = text["input_ids"] + if input_ids.shape[0] < seqlen: + continue + input_ids = input_ids[:seqlen] + input_ids_list = input_ids.tolist() + if input_ids_list.count(input_ids_list[-1]) > seqlen // 2: + continue + input_ids_new.append(input_ids) + if len(input_ids_new) == 0: + return None + tmp = torch.vstack(input_ids_new) + res = {"input_ids": tmp} + return res + + calib_dataset = load_dataset(dataset_name, split=split) + calib_dataset = calib_dataset.shuffle(seed=seed) + calib_dataset = calib_dataset.map(tokenizer_function, batched=True) + calib_dataset.set_format(type="torch", columns=["input_ids"]) + calib_dataloader = DataLoader(calib_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch) + return calib_dataloader