diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py index 601408fceb..318136be56 100644 --- a/py/torch_tensorrt/fx/lower.py +++ b/py/torch_tensorrt/fx/lower.py @@ -92,7 +92,18 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult: input_specs_val = ( self.lower_setting.input_specs if self.lower_setting.input_specs - else InputTensorSpec.from_tensors(input) + else ( + InputTensorSpec.from_tensors_with_dynamic_batch_size( + input, + ( + 0, + self.lower_setting.max_batch_size, + self.lower_setting.max_batch_size, + ), + ) + if self.lower_setting.explicit_batch_dimension + else InputTensorSpec.from_tensors(input) + ) ) # Prepare algorithm selector and timing_cache for TRTInterpreter diff --git a/py/torch_tensorrt/fx/lower_setting.py b/py/torch_tensorrt/fx/lower_setting.py index b9cbb2630d..b1a32c2cff 100644 --- a/py/torch_tensorrt/fx/lower_setting.py +++ b/py/torch_tensorrt/fx/lower_setting.py @@ -63,6 +63,12 @@ class LowerSetting(LowerSettingBasic): save_timing_cache: Save updated timing cache data into timing cache file if the timing cache file is provided. cuda_graph_batch_size (int): Cuda graph batch size, default to be -1. + preset_lowerer (str): when specified, use a preset logic to build the + instance of Lowerer. Refer to + `caffe2.torch.fb.model_transform.fx2trt.presets.LowererPresetsManager` on + how presets are applied. Refer to + `caffe2.torch.fb.model_transform.fx2trt.presets.ESUHMLowererPreset` on how + to add a preset. """ input_specs: List[InputTensorSpec] = dc.field(default_factory=list) @@ -79,3 +85,4 @@ class LowerSetting(LowerSettingBasic): timing_cache_prefix: str = "" save_timing_cache: bool = False cuda_graph_batch_size: int = -1 + preset_lowerer: str = "" diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass.py b/py/torch_tensorrt/fx/passes/lower_basic_pass.py index 6dc2e86f22..4394ca97b4 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass.py @@ -31,7 +31,9 @@ def skip_folding_quant_dequant(node: torch.fx.Node): return True return False - const_split_mod = split_const_subgraphs(traced_mod, skip_folding_quant_dequant) + const_split_mod = split_const_subgraphs( + traced_mod, skip_folding_quant_dequant, device_for_folded_attrs="cuda" + ) const_split_mod.run_folding() return const_split_mod diff --git a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py index 231b8eed0c..a78329c9ef 100644 --- a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py +++ b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py @@ -2576,5 +2576,6 @@ def test_all_acc_ops_registered(self): acc_ops.new_ones, acc_ops.einsum, acc_ops.as_strided, + acc_ops.var, }, ) diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py index df6480166b..b28bf263c2 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py @@ -2864,3 +2864,18 @@ def as_strided(*, input, size, stride, storage_offset=0): return torch.as_strided( input=input, size=size, stride=stride, storage_offset=storage_offset ) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.var)) +@register_acc_op_mapping( + op_and_target=("call_method", "var"), + arg_replacement_tuples=[ + ("input", "input"), + ("dim", "dim"), + ("unbiased", "unbiased"), + ("keepdim", "keepdim"), + ], +) +@register_acc_op +def var(*, input, dim, unbiased, keepdim=False): + return torch.var(input=input, dim=dim, unbiased=unbiased, keepdim=keepdim)