From fe6b93b19b82edba40535f4f58502701154ff71b Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 17 Aug 2024 11:14:43 +0800 Subject: [PATCH 1/6] add device argument to quantize_() --- torchao/quantization/quant_api.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index a0ad665eac..7e36a1699a 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -161,6 +161,7 @@ def _replace_with_custom_fn_if_matches_filter( replacement_fn, filter_fn, cur_fqn="", + device=None, ) -> None: """ Recursively replaces each child module in `model` with the result of `replacement_fn(child)` @@ -171,20 +172,23 @@ def _replace_with_custom_fn_if_matches_filter( replacement_fn (Callable[[torch.nn.Module], torch.nn.Module]): The function to replace matching modules. filter_fn (Callable[[torch.nn.Module], bool]): The filter function to determine which modules to replace. cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "". + device (device, optional): Device to move the model to before applying `filter_fn`. Defaults to None. Returns: None """ if filter_fn(model, cur_fqn[:-1]): + model.to(device=device) # move to device before quantization model = replacement_fn(model) return model else: for name, child in model.named_children(): new_child = _replace_with_custom_fn_if_matches_filter( - child, replacement_fn, filter_fn, f"{cur_fqn}{name}." + child, replacement_fn, filter_fn, f"{cur_fqn}{name}.", device ) if new_child is not child: setattr(model, name, new_child) + model.to(device=device) # move parent module to device return model @@ -269,7 +273,13 @@ def insert_subclass(lin): return insert_subclass -def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True): +def quantize_( + model: torch.nn.Module, + apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module], + filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, + set_inductor_config: bool = True, + device: Optional[torch.types.Device] = None, +): """Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace Args: @@ -278,6 +288,8 @@ def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.nn. filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on the weight of the module set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True) + device (device, optional): Device to move module to before applying `filter_fn`. This can be set to `"cuda"` to speed up quantization. The final model will be on the specified `device`. + Defaults to None (do not change device). 
Example:: @@ -329,6 +341,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: model, apply_tensor_subclass, _is_linear if filter_fn is None else filter_fn, + device=device, ) def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: From 8fee8f88b46ba2a2559e2cdac8384268ee9756b2 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 17 Aug 2024 18:56:16 +0800 Subject: [PATCH 2/6] fix test --- torchao/quantization/quant_api.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 7e36a1699a..49345adf39 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -178,7 +178,8 @@ def _replace_with_custom_fn_if_matches_filter( None """ if filter_fn(model, cur_fqn[:-1]): - model.to(device=device) # move to device before quantization + if device is not None: + model.to(device=device) # move to device before quantization model = replacement_fn(model) return model else: @@ -188,7 +189,8 @@ def _replace_with_custom_fn_if_matches_filter( ) if new_child is not child: setattr(model, name, new_child) - model.to(device=device) # move parent module to device + if device is not None: + model.to(device=device) # move parent module to device return model From 4550fc5b6cf55e8751dcf7a16890eb14e8e732e2 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 17 Aug 2024 21:20:39 +0800 Subject: [PATCH 3/6] add test --- test/quantization/test_quant_api.py | 34 +++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index a1b29320de..023d495f0b 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -54,6 +54,8 @@ from torchao.utils import unwrap_tensor_subclass import copy import tempfile +import gc +import time from torch.testing._internal.common_utils import TestCase @@ -680,6 +682,38 @@ def test_quantized_tensor_subclass_save_load_map_location(self): res = m_copy(*example_inputs) self.assertEqual(res, ref) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_quantized_model_streaming(self): + def reset_memory(): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + reset_memory() + m = ToyLinearModel() + time0 = time.perf_counter() + m.to(device="cuda") + quantize_(m, int8_weight_only()) + torch.cuda.synchronize() + time_baseline = time.perf_counter() - time0 + memory_baseline = torch.cuda.max_memory_allocated() + print(memory_baseline) + + del m + reset_memory() + m = ToyLinearModel() + time0 = time.perf_counter() + quantize_(m, int8_weight_only(), device="cuda") + time_streaming = time.perf_counter() - time0 + memory_streaming = torch.cuda.max_memory_allocated() + print(memory_streaming) + + for param in m.parameters(): + assert param.is_cuda + self.assertLess(time_streaming, time_baseline * 1.1) + self.assertLess(memory_streaming, memory_baseline) + if __name__ == "__main__": unittest.main() From a53a5ff161b3ede4248493d25e9f552ee4632ea5 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 17 Aug 2024 21:23:30 +0800 Subject: [PATCH 4/6] remove print --- test/quantization/test_quant_api.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 023d495f0b..0cc5df618b 100644 --- a/test/quantization/test_quant_api.py +++ 
b/test/quantization/test_quant_api.py @@ -698,7 +698,6 @@ def reset_memory(): torch.cuda.synchronize() time_baseline = time.perf_counter() - time0 memory_baseline = torch.cuda.max_memory_allocated() - print(memory_baseline) del m reset_memory() @@ -707,7 +706,6 @@ def reset_memory(): quantize_(m, int8_weight_only(), device="cuda") time_streaming = time.perf_counter() - time0 memory_streaming = torch.cuda.max_memory_allocated() - print(memory_streaming) for param in m.parameters(): assert param.is_cuda From 5218d43f62577c781e88c88b868b5e917dfb16d6 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 17 Aug 2024 22:54:38 +0800 Subject: [PATCH 5/6] fix --- test/quantization/test_quant_api.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 0cc5df618b..4465b48ba9 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -704,12 +704,13 @@ def reset_memory(): m = ToyLinearModel() time0 = time.perf_counter() quantize_(m, int8_weight_only(), device="cuda") + torch.cuda.synchronize() time_streaming = time.perf_counter() - time0 memory_streaming = torch.cuda.max_memory_allocated() for param in m.parameters(): assert param.is_cuda - self.assertLess(time_streaming, time_baseline * 1.1) + self.assertLess(time_streaming, time_baseline * 1.5) self.assertLess(memory_streaming, memory_baseline) From 711d0010144cbc3d9538e7572d5b2532fe4a1068 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 18 Aug 2024 00:02:54 +0800 Subject: [PATCH 6/6] remove timing check --- test/quantization/test_quant_api.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 4465b48ba9..e61ce7b4fb 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -55,7 +55,6 @@ import copy import tempfile import gc -import time from torch.testing._internal.common_utils import TestCase @@ -692,25 +691,17 @@ def reset_memory(): reset_memory() m = ToyLinearModel() - time0 = time.perf_counter() - m.to(device="cuda") - quantize_(m, int8_weight_only()) - torch.cuda.synchronize() - time_baseline = time.perf_counter() - time0 + quantize_(m.to(device="cuda"), int8_weight_only()) memory_baseline = torch.cuda.max_memory_allocated() del m reset_memory() m = ToyLinearModel() - time0 = time.perf_counter() quantize_(m, int8_weight_only(), device="cuda") - torch.cuda.synchronize() - time_streaming = time.perf_counter() - time0 memory_streaming = torch.cuda.max_memory_allocated() for param in m.parameters(): assert param.is_cuda - self.assertLess(time_streaming, time_baseline * 1.5) self.assertLess(memory_streaming, memory_baseline)
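
Note on the net effect of patches 1 and 2: together they amount to the streaming
traversal sketched below. This is a paraphrased, standalone condensation; the names
and control flow follow the diffs above, but the in-tree
`_replace_with_custom_fn_if_matches_filter` remains the source of truth::

    import torch
    from torch import nn

    def replace_streaming(model, replacement_fn, filter_fn, fqn="", device=None):
        # A matching module is moved to `device` *before* replacement_fn
        # (e.g. quantization) runs, so at most one unconverted module needs
        # to be resident on the target device at any point in the traversal.
        if filter_fn(model, fqn[:-1]):
            if device is not None:
                model.to(device=device)
            return replacement_fn(model)
        for name, child in model.named_children():
            new_child = replace_streaming(
                child, replacement_fn, filter_fn, f"{fqn}{name}.", device
            )
            if new_child is not child:
                setattr(model, name, new_child)
        if device is not None:
            # Move parameters owned by non-matching modules (patch 1's
            # "move parent module to device"), guarded by the None check
            # that patch 2 adds.
            model.to(device=device)
        return model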
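
And a minimal usage sketch of the resulting API, assuming a CUDA device is
available and that `quantize_` and `int8_weight_only` are importable from
`torchao.quantization`; the two-layer `nn.Sequential` is a hypothetical
stand-in for the test suite's `ToyLinearModel` helper::

    import torch
    from torch import nn
    from torchao.quantization import quantize_, int8_weight_only

    model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))

    # Baseline (the first half of test_quantized_model_streaming): the whole
    # fp32 model is moved to CUDA before quantization, so peak GPU memory
    # must hold every fp32 weight at once.
    #   quantize_(model.to(device="cuda"), int8_weight_only())

    # Streaming (the second half): each linear layer is moved to CUDA and
    # quantized one at a time, keeping peak GPU memory below a full fp32
    # copy of the model, which is the property the remaining assertion checks.
    quantize_(model, int8_weight_only(), device="cuda")

    assert all(p.is_cuda for p in model.parameters())

The memory assertion (`memory_streaming < memory_baseline`) is deterministic for a
fixed model, which is presumably why patches 5 and 6 first loosen and then drop the
wall-clock timing check: timings vary across machines and make such bounds flaky in CI.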