diff --git a/examples/dynamo/refit_engine_example.py b/examples/dynamo/refit_engine_example.py
index adf1057055..f93b097385 100644
--- a/examples/dynamo/refit_engine_example.py
+++ b/examples/dynamo/refit_engine_example.py
@@ -70,6 +70,7 @@
     min_block_size=min_block_size,
     torch_executed_ops=torch_executed_ops,
     make_refittable=True,
+    reuse_cached_engines=False,
 )  # Output is a torch.fx.GraphModule
 
 # Save the graph module as an exported program
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 19d80e70b1..1358e034c7 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -477,12 +477,18 @@ def _save_weight_mapping(self) -> None:
             # Retrieve each weight name(s) in state_dict
             if layer_type == "CONSTANT":
                 if "embedding" in suffix:
-                    sd_weight_name = f"{sd_weight_name}.{torch_attr[0]}"
+                    sd_weight_name = f"{sd_weight_name}.weight"
                 elif "weight" in suffix or "mm_other" in suffix:
                     # Linear layer weight
-                    sd_weight_name = f"{sd_weight_name}.{torch_attr[0]}"
+                    sd_weight_name = f"{sd_weight_name}.weight"
+                elif "running_mean" in suffix:
+                    # Batch norm running mean
+                    sd_weight_name = f"{sd_weight_name}.running_mean"
+                elif "running_var" in suffix:
+                    # Batch norm running variance
+                    sd_weight_name = f"{sd_weight_name}.running_var"
                 else:
-                    sd_weight_name = f"{sd_weight_name}.{torch_attr[1]}"
+                    sd_weight_name = f"{sd_weight_name}.bias"
             elif layer_type == "SCALE":
                 # Batch norm needs all weights to calculate scale and shift
                 sd_weight_name = [f"{sd_weight_name}.{n}" for n in torch_attr]
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index 587bfd0373..0f8baa76b7 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -50,14 +50,27 @@ def batch_norm(
     # Save the original output shape for later use
     output_shape = input.shape
 
-    if weight is None:
-        weight = get_trt_tensor(ctx, 1.0, f"{name}_weight")
-    if bias is None:
-        bias = get_trt_tensor(ctx, 0.0, f"{name}_bias")
-    if running_mean is None:
-        running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
-    if running_var is None:
-        running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var")
+    # We name the weights here according to their state_dict names
+    weight = (
+        get_trt_tensor(ctx, 1.0, f"{name}_weight")
+        if weight is None
+        else get_trt_tensor(ctx, weight, f"{name}_weight")
+    )
+    bias = (
+        get_trt_tensor(ctx, 0.0, f"{name}_bias")
+        if bias is None
+        else get_trt_tensor(ctx, bias, f"{name}_bias")
+    )
+    running_mean = (
+        get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
+        if running_mean is None
+        else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
+    )
+    running_var = (
+        get_trt_tensor(ctx, 1.0, f"{name}_running_var")
+        if running_var is None
+        else get_trt_tensor(ctx, running_var, f"{name}_running_var")
+    )
 
     # eps_tensor for numerical stability
     eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")
diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py
index 0f6fb05914..46ffa7b6d8 100644
--- a/tests/py/dynamo/models/test_model_refit.py
+++ b/tests/py/dynamo/models/test_model_refit.py
@@ -35,8 +35,8 @@
 
 @pytest.mark.unit
 def test_mapping():
-    model = models.resnet18(pretrained=True).eval().to("cuda")
-    model2 = models.resnet18(pretrained=False).eval().to("cuda")
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     trt_input = [
         torchtrt.Input(i.shape, dtype=torch.float, format=torch.contiguous_format)
@@ -58,6 +58,7 @@ def test_mapping():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
     settings = trt_gm._run_on_acc_0.settings
     runtime = trt.Runtime(TRT_LOGGER)
@@ -110,6 +111,7 @@ def test_refit_one_engine_with_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
 
     new_trt_gm = refit_module_weights(
@@ -141,8 +143,8 @@
 
 @pytest.mark.unit
 def test_refit_one_engine_no_map_with_weightmap():
-    model = models.resnet18(pretrained=True).eval().to("cuda")
-    model2 = models.resnet18(pretrained=False).eval().to("cuda")
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -160,6 +162,7 @@ def test_refit_one_engine_no_map_with_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
 
     trt_gm._run_on_acc_0.weight_name_map = None
@@ -192,8 +195,8 @@
 
 @pytest.mark.unit
 def test_refit_one_engine_with_wrong_weightmap():
-    model = models.resnet18(pretrained=True).eval().to("cuda")
-    model2 = models.resnet18(pretrained=False).eval().to("cuda")
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -211,6 +214,7 @@ def test_refit_one_engine_with_wrong_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
     # Manually Deleted all batch norm layer. This suppose to fail the fast refit
     trt_gm._run_on_acc_0.weight_name_map = {
@@ -268,6 +272,7 @@ def test_refit_one_engine_bert_with_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
 
     new_trt_gm = refit_module_weights(
@@ -302,8 +307,8 @@
 @pytest.mark.unit
 def test_refit_one_engine_inline_runtime__with_weightmap():
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    model = models.resnet18(pretrained=True).eval().to("cuda")
-    model2 = models.resnet18(pretrained=False).eval().to("cuda")
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -321,6 +326,7 @@ def test_refit_one_engine_inline_runtime__with_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
     torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
     trt_gm = torch.export.load(trt_ep_path)
@@ -348,8 +354,8 @@
 
 @pytest.mark.unit
 def test_refit_one_engine_python_runtime_with_weightmap():
-    model = models.resnet18(pretrained=True).eval().to("cuda")
-    model2 = models.resnet18(pretrained=False).eval().to("cuda")
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -367,6 +373,7 @@ def test_refit_one_engine_python_runtime_with_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
 
     new_trt_gm = refit_module_weights(
@@ -438,6 +445,7 @@ def forward(self, x):
         min_block_size=min_block_size,
         make_refittable=True,
         torch_executed_ops=torch_executed_ops,
+        reuse_cached_engines=False,
     )
 
     new_trt_gm = refit_module_weights(
@@ -487,6 +495,7 @@ def test_refit_one_engine_without_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
 
     new_trt_gm = refit_module_weights(
@@ -538,6 +547,7 @@ def test_refit_one_engine_bert_without_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
 
     new_trt_gm = refit_module_weights(
@@ -591,6 +601,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
     torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
     trt_gm = torch.export.load(trt_ep_path)
@@ -637,6 +648,7 @@ def test_refit_one_engine_python_runtime_without_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
 
     new_trt_gm = refit_module_weights(
@@ -708,6 +720,7 @@ def forward(self, x):
         min_block_size=min_block_size,
         make_refittable=True,
         torch_executed_ops=torch_executed_ops,
+        reuse_cached_engines=False,
     )
 
     new_trt_gm = refit_module_weights(
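For reference, below is a rough, self-contained sketch of the refit workflow these changes exercise: compile a randomly initialized ResNet-18 as a refittable engine, then swap in pretrained weights via `refit_module_weights` without rebuilding. This sketch is illustrative and not part of the patch; keyword names such as `arg_inputs` are assumptions that vary slightly across Torch-TensorRT versions.

```python
import torch
import torchvision.models as models

import torch_tensorrt
from torch_tensorrt.dynamo import refit_module_weights

# Compile the randomly initialized model first, then refit with the
# pretrained weights, so a successful refit is observable in the outputs.
model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]

exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torch_tensorrt.dynamo.compile(
    exp_program,
    tuple(inputs),
    make_refittable=True,
    # Force a fresh build: a cached engine would skip conversion and
    # therefore never emit the updated weight-name map.
    reuse_cached_engines=False,
)

# Refit the compiled module with model2's weights in place.
exp_program2 = torch.export.export(model2, tuple(inputs))
new_trt_gm = refit_module_weights(
    compiled_module=trt_gm,
    new_weight_module=exp_program2,
    arg_inputs=inputs,  # kwarg name is version-dependent (may be `inputs`)
)

# The refitted engine should now track the pretrained model.
assert torch.allclose(model2(*inputs), new_trt_gm(*inputs), atol=1e-2)
```

The `reuse_cached_engines=False` additions throughout the tests presumably serve the same purpose as in the sketch: with engine caching enabled, a previously serialized engine would be loaded instead of reconverted, so the corrected `.running_mean`/`.running_var` entries in the weight map would never be generated or exercised.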