diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index a1b29320de..e61ce7b4fb 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -54,6 +54,7 @@ from torchao.utils import unwrap_tensor_subclass
 
 import copy
 import tempfile
+import gc
 
 from torch.testing._internal.common_utils import TestCase
@@ -680,6 +681,29 @@ def test_quantized_tensor_subclass_save_load_map_location(self):
         res = m_copy(*example_inputs)
         self.assertEqual(res, ref)
 
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_quantized_model_streaming(self):
+        def reset_memory():
+            gc.collect()
+            torch.cuda.empty_cache()
+            torch.cuda.reset_peak_memory_stats()
+
+        reset_memory()
+        m = ToyLinearModel()
+        quantize_(m.to(device="cuda"), int8_weight_only())
+        memory_baseline = torch.cuda.max_memory_allocated()
+
+        del m
+        reset_memory()
+        m = ToyLinearModel()
+        quantize_(m, int8_weight_only(), device="cuda")
+        memory_streaming = torch.cuda.max_memory_allocated()
+
+        for param in m.parameters():
+            assert param.is_cuda
+        self.assertLess(memory_streaming, memory_baseline)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index a0ad665eac..49345adf39 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -161,6 +161,7 @@ def _replace_with_custom_fn_if_matches_filter(
     replacement_fn,
     filter_fn,
     cur_fqn="",
+    device=None,
 ) -> None:
     """
     Recursively replaces each child module in `model` with the result of `replacement_fn(child)`
@@ -171,20 +172,25 @@
         replacement_fn (Callable[[torch.nn.Module], torch.nn.Module]): The function to replace matching modules.
         filter_fn (Callable[[torch.nn.Module], bool]): The filter function to determine which modules to replace.
         cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "".
+        device (device, optional): Device to move the model to before applying `filter_fn`. Defaults to None.
 
     Returns:
         None
     """
     if filter_fn(model, cur_fqn[:-1]):
+        if device is not None:
+            model.to(device=device)  # move to device before quantization
         model = replacement_fn(model)
         return model
     else:
         for name, child in model.named_children():
             new_child = _replace_with_custom_fn_if_matches_filter(
-                child, replacement_fn, filter_fn, f"{cur_fqn}{name}."
+                child, replacement_fn, filter_fn, f"{cur_fqn}{name}.", device
             )
             if new_child is not child:
                 setattr(model, name, new_child)
+        if device is not None:
+            model.to(device=device)  # move parent module to device
         return model
 
 
@@ -269,7 +275,13 @@ def insert_subclass(lin):
     return insert_subclass
 
 
-def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True):
+def quantize_(
+    model: torch.nn.Module,
+    apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module],
+    filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
+    set_inductor_config: bool = True,
+    device: Optional[torch.types.Device] = None,
+):
     """Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace
 
     Args:
@@ -278,6 +290,8 @@ def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.nn.
         filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on the weight of the module
         set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)
+        device (device, optional): Device to move module to before applying `filter_fn`. This can be set to `"cuda"` to speed up quantization. The final model will be on the specified `device`.
+            Defaults to None (do not change device).
 
     Example::
 
@@ -329,6 +343,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
         model,
         apply_tensor_subclass,
         _is_linear if filter_fn is None else filter_fn,
+        device=device,
     )
 
 def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
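For reference, a minimal usage sketch of the new `device` argument. It uses `quantize_` and `int8_weight_only` from the diff above; the two-layer `Sequential` model is a hypothetical stand-in, not part of the PR, and a CUDA device is assumed to be available:

```python
import torch
from torchao.quantization.quant_api import int8_weight_only, quantize_

# Build the model on CPU; nothing touches the GPU yet.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.Linear(1024, 1024),
)

# With device="cuda", each matching linear module is moved to the GPU and
# quantized one at a time, so peak CUDA memory stays below the size of the
# full unquantized model (what test_quantized_model_streaming asserts).
quantize_(model, int8_weight_only(), device="cuda")

# The final model lives on the specified device.
assert all(p.is_cuda for p in model.parameters())
```

With the default `device=None`, `quantize_` behaves as before and leaves module placement unchanged.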