From cdb8f8abfc87409a7e4bb7602e6fb221e014ba48 Mon Sep 17 00:00:00 2001 From: tinyinl Date: Fri, 19 May 2023 17:47:49 +0000 Subject: [PATCH 1/3] add 3 tags --- .../dynamo/fx_ts_compat/fx2trt.py | 18 ++++ .../dynamo/fx_ts_compat/lower.py | 3 + .../dynamo/fx_ts_compat/lower_setting.py | 6 ++ .../dynamo/fx_ts_compat/passes/pass_utils.py | 87 ++++++++++--------- 4 files changed, 72 insertions(+), 42 deletions(-) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py index b5165c6f2d..6848d75527 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py @@ -163,6 +163,9 @@ def run( timing_cache=None, profiling_verbosity=None, tactic_sources=None, + max_aux_streams=None, + version_compatible=False, + optimization_level=None, ) -> TRTInterpreterResult: """ Build TensorRT engine with some configs. @@ -227,6 +230,18 @@ def run( if profiling_verbosity else trt.ProfilingVerbosity.LAYER_NAMES_ONLY ) + + if trt.__version__ >= "8.6": + if max_aux_streams is not None: + _LOGGER.info(f"Setting max aux streams to {max_aux_streams}") + builder_config.max_aux_streams = max_aux_streams + if version_compatible: + _LOGGER.info(f"Using version compatible") + builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE) + if optimization_level is not None: + _LOGGER.info(f"Using optimization level {optimization_level}") + builder_config.builder_optimization_level = optimization_level + if lower_precision == LowerPrecision.FP16: builder_config.set_flag(trt.BuilderFlag.FP16) @@ -264,6 +279,9 @@ def run( _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) + _LOGGER.info( + f"TRT Engine uses: {engine.device_memory_size} bytes of Memory" + ) return TRTInterpreterResult( engine, self._input_names, self._output_names, serialized_cache diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py index 60ace0f12a..8131edb540 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py @@ -181,6 +181,9 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult: if self.lower_setting.verbose_profile else trt.ProfilingVerbosity.LAYER_NAMES_ONLY, tactic_sources=self.lower_setting.tactic_sources, + max_aux_streams=self.lower_setting.max_aux_streams, + version_compatible=self.lower_setting.version_compatible, + optimization_level=self.lower_setting.optimization_level, ) # Update timing cache file if needed diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py index 9008bbe8e9..64fa1bf267 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py @@ -70,6 +70,9 @@ class LowerSetting(LowerSettingBasic): correctness_atol: absolute tolerance for correctness check correctness_rtol: relative tolerance for correctness check use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++). + max_aux_streams: max number of aux stream to use + version_compatible: enable version compatible feature + optimization_level: builder optimization level """ input_specs: List[InputTensorSpec] = dc.field(default_factory=list) @@ -96,3 +99,6 @@ class LowerSetting(LowerSettingBasic): correctness_atol: float = 0.1 correctness_rtol: float = 0.1 use_experimental_rt: bool = False + max_aux_streams: Optional[int] = None + version_compatible: bool = False + optimization_level: Optional[int] = None diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py index 96fa96cfae..5252f537f5 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py @@ -126,7 +126,7 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule: # (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall # on pass that failed accuracy check. def validate_inference( - rtol=None, atol=None, device=torch.device(torch.cuda.current_device()) + rtol=None, atol=None, device=torch.device(torch.cuda.current_device(), suppress_accuracy_check_failure=True) ): def _validate_inference(pass_: PassFunc) -> PassFunc: """ @@ -141,48 +141,51 @@ def pass_with_validation( *args, **kwargs, ) -> fx.GraphModule: - input_tensors = extract_example_tensors_from_input(input, device) - res0 = module(*input_tensors) - processed_module = pass_(module, input, *args, **kwargs) - res1 = processed_module(*input_tensors) - tensor_res_0 = _collect_tensors(res0) - tensor_res_1 = _collect_tensors(res1) - relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE - - for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)): - kwargs2 = {"equal_nan": True} - if rtol: - kwargs2["rtol"] = rtol - if atol: - kwargs2["atol"] = atol - kwargs2[ - "msg" - ] = ( - lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}" - ) - # If tensors are on different devices, make sure to compare - # their copies that are on the same device. - if x.get_device() != y.get_device(): - x = x.cpu() - y = y.cpu() - try: - torch.testing.assert_close(x, y, **kwargs2) - except Exception as e: - if relax_accuracy_check_failure: - _LOGGER.error(f"{e}") - kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER - kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER - new_atol = kwargs2["atol"] - new_rtol = kwargs2["rtol"] - _LOGGER.info( - f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}" - ) + if suppress_accuracy_check_failure: + return pass_(module, input, *args, **kwargs) + else: + input_tensors = extract_example_tensors_from_input(input, device) + res0 = module(*input_tensors) + processed_module = pass_(module, input, *args, **kwargs) + res1 = processed_module(*input_tensors) + tensor_res_0 = _collect_tensors(res0) + tensor_res_1 = _collect_tensors(res1) + relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE + + for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)): + kwargs2 = {"equal_nan": True} + if rtol: + kwargs2["rtol"] = rtol + if atol: + kwargs2["atol"] = atol + kwargs2[ + "msg" + ] = ( + lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}" + ) + # If tensors are on different devices, make sure to compare + # their copies that are on the same device. + if x.get_device() != y.get_device(): + x = x.cpu() + y = y.cpu() + try: torch.testing.assert_close(x, y, **kwargs2) - return processed_module - else: - raise e - - return processed_module + except Exception as e: + if relax_accuracy_check_failure: + _LOGGER.error(f"{e}") + kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER + kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER + new_atol = kwargs2["atol"] + new_rtol = kwargs2["rtol"] + _LOGGER.info( + f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}" + ) + torch.testing.assert_close(x, y, **kwargs2) + return processed_module + else: + raise e + + return processed_module return pass_with_validation From 402abcfd84cc67fd26f8af841a971b3ccd0234e5 Mon Sep 17 00:00:00 2001 From: tinyinl Date: Fri, 19 May 2023 23:52:41 +0000 Subject: [PATCH 2/3] fix typo --- py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py index 5252f537f5..54eea0ad30 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py @@ -126,7 +126,7 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule: # (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall # on pass that failed accuracy check. def validate_inference( - rtol=None, atol=None, device=torch.device(torch.cuda.current_device(), suppress_accuracy_check_failure=True) + rtol=None, atol=None, device=torch.device(torch.cuda.current_device()), suppress_accuracy_check_failure=True ): def _validate_inference(pass_: PassFunc) -> PassFunc: """ From d660256ed22ddf7dd12f8e465476eba4365b2c89 Mon Sep 17 00:00:00 2001 From: tinyinl Date: Mon, 22 May 2023 18:07:37 +0000 Subject: [PATCH 3/3] reformat --- py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py | 4 +--- py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py | 7 +++++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py index 6848d75527..e4298600cb 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py @@ -279,9 +279,7 @@ def run( _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) - _LOGGER.info( - f"TRT Engine uses: {engine.device_memory_size} bytes of Memory" - ) + _LOGGER.info(f"TRT Engine uses: {engine.device_memory_size} bytes of Memory") return TRTInterpreterResult( engine, self._input_names, self._output_names, serialized_cache diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py index 54eea0ad30..7d3046d617 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py @@ -126,7 +126,10 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule: # (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall # on pass that failed accuracy check. def validate_inference( - rtol=None, atol=None, device=torch.device(torch.cuda.current_device()), suppress_accuracy_check_failure=True + rtol=None, + atol=None, + device=torch.device(torch.cuda.current_device()), + suppress_accuracy_check_failure=True, ): def _validate_inference(pass_: PassFunc) -> PassFunc: """ @@ -141,7 +144,7 @@ def pass_with_validation( *args, **kwargs, ) -> fx.GraphModule: - if suppress_accuracy_check_failure: + if suppress_accuracy_check_failure: return pass_(module, input, *args, **kwargs) else: input_tensors = extract_example_tensors_from_input(input, device)