From 076e0f74da385e1987d4663c651783cdda35cb1e Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 12 Apr 2024 13:50:43 -0700 Subject: [PATCH] fix: Missing parameters in compiler settings --- py/torch_tensorrt/ts/_compiler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/ts/_compiler.py b/py/torch_tensorrt/ts/_compiler.py index 4a9bb53dc0..e101ebe25d 100644 --- a/py/torch_tensorrt/ts/_compiler.py +++ b/py/torch_tensorrt/ts/_compiler.py @@ -4,11 +4,12 @@ import torch import torch_tensorrt._C.ts as _C -from torch_tensorrt import _enums from torch_tensorrt._Device import Device from torch_tensorrt._Input import Input from torch_tensorrt.ts._compile_spec import _parse_compile_spec, _parse_device +from torch_tensorrt import _enums + def compile( module: torch.jit.ScriptModule, @@ -137,6 +138,9 @@ def compile( "capability": capability, # Restrict kernel selection to safe gpu kernels or safe dla kernels "num_avg_timing_iters": num_avg_timing_iters, # Number of averaging timing iterations used to select kernels "workspace_size": workspace_size, # Maximum size of workspace given to TensorRT + "dla_sram_size": dla_sram_size, + "dla_local_dram_size": dla_local_dram_size, + "dla_global_dram_size": dla_global_dram_size, "calibrator": calibrator, "truncate_long_and_double": truncate_long_and_double, "torch_fallback": {