Skip to content

Commit 69e49e8

Browse files
committed
feat: update truncate long/double python api
Signed-off-by: inocsin <[email protected]>
1 parent 740eb54 commit 69e49e8

File tree

6 files changed: 18 additions (+18) and 4 deletions (-4)

core/conversion/var/Var.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,14 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
9797
auto weights = converters::Weights();
9898
if (isIValue()) {
9999
auto tensor = ptr_.ivalue->toTensor();
100-
if (tensor.scalar_type() == at::kLong && ctx->settings.truncate_long_and_double) {
100+
if ((tensor.scalar_type() == at::kLong || tensor.scalar_type() == at::kDouble) && !ctx->settings.truncate_long_and_double) {
101+
TRTORCH_CHECK(0, "Unable to freeze tensor of type kLong/kDouble into constant layer, try to compile model with truncate_long_and_double ON");
102+
} else if (tensor.scalar_type() == at::kLong && ctx->settings.truncate_long_and_double) {
101103
weights = converters::Weights(ctx, tensor.toType(at::kInt));
102-
LOG_WARNING("Truncate kLong to kInt for IValue");
104+
LOG_WARNING("Warning: Truncating weight (constant in the graph) from kLong to kInt to indicate that only constants are affected.");
103105
} else if (tensor.scalar_type() == at::kDouble && ctx->settings.truncate_long_and_double) {
104106
weights = converters::Weights(ctx, tensor.toType(at::kFloat));
105-
LOG_WARNING("Truncate kDouble to kFloat for IValue");
107+
LOG_WARNING("Warning: Truncating weight (constant in the graph) from kDouble to kFloat to indicate that only constants are affected.");
106108
} else {
107109
weights = converters::Weights(ctx, tensor);
108110
}

py/trtorch/_compile_spec.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
176176
if "max_batch_size" in compile_spec:
177177
assert type(compile_spec["max_batch_size"]) is int
178178
info.max_batch_size = compile_spec["max_batch_size"]
179+
180+
if "truncate_long_and_double" in compile_spec:
181+
assert type(compile_spec["truncate_long_and_double"]) is bool
182+
info.truncate_long_and_double = compile_spec["truncate_long_and_double"]
179183

180184
return info
181185

@@ -217,6 +221,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
217221
"num_avg_timing_iters": 1, # Number of averaging timing iterations used to select kernels
218222
"workspace_size": 0, # Maximum size of workspace given to TensorRT
219223
"max_batch_size": 0, # Maximum batch size (must be >= 1 to be set, 0 means not set)
224+
"truncate_long_and_double": False, # Truncate long and double into int and float
220225
})
221226
}
222227
@@ -257,6 +262,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
257262
backend_spec.set_num_avg_timing_iters(parsed_spec.num_avg_timing_iters)
258263
backend_spec.set_workspace_size(parsed_spec.workspace_size)
259264
backend_spec.set_max_batch_size(parsed_spec.max_batch_size)
265+
backend_spec.set_truncate_long_and_double(parsed_spec.truncate_long_and_double)
260266
backend_spec._set_ptq_calibrator(parsed_spec._get_calibrator_handle())
261267

262268
return backend_spec

py/trtorch/csrc/register_tensorrt_classes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ void RegisterTRTCompileSpec() {
4242
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, num_avg_timing_iters);
4343
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, workspace_size);
4444
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, max_batch_size);
45+
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, truncate_long_and_double);
4546
}
4647

4748
struct TRTTSRegistrations {

py/trtorch/csrc/tensorrt_classes.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
108108
info.convert_info.engine_settings.device.gpu_id = device.gpu_id;
109109
info.convert_info.engine_settings.device.dla_core = device.dla_core;
110110
info.convert_info.engine_settings.device.allow_gpu_fallback = device.allow_gpu_fallback;
111+
info.convert_info.engine_settings.truncate_long_and_double = truncate_long_and_double;
111112

112113
info.convert_info.engine_settings.capability = toTRTEngineCapability(capability);
113114
TRTORCH_CHECK(num_min_timing_iters >= 0, "num_min_timing_iters must be 0 or greater");
@@ -143,6 +144,7 @@ std::string CompileSpec::stringify() {
143144
ss << " \"Num Avg Timing Iters\": " << num_avg_timing_iters << std::endl;
144145
ss << " \"Workspace Size\": " << workspace_size << std::endl;
145146
ss << " \"Max Batch Size\": " << max_batch_size << std::endl;
147+
ss << " \"Truncate long and double\": " << truncate_long_and_double << std::endl;
146148
ss << "}";
147149
return ss.str();
148150
}

py/trtorch/csrc/tensorrt_classes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ struct CompileSpec : torch::CustomClassHolder {
115115
ADD_FIELD_GET_SET(num_min_timing_iters, int64_t);
116116
ADD_FIELD_GET_SET(num_avg_timing_iters, int64_t);
117117
ADD_FIELD_GET_SET(workspace_size, int64_t);
118+
ADD_FIELD_GET_SET(truncate_long_and_double, bool);
118119
ADD_FIELD_GET_SET(max_batch_size, int64_t);
119120
ADD_FIELD_GET_SET(device, Device);
120121
ADD_FIELD_GET_SET(ptq_calibrator, nvinfer1::IInt8Calibrator*);
@@ -126,6 +127,7 @@ struct CompileSpec : torch::CustomClassHolder {
126127
bool refit = false;
127128
bool debug = false;
128129
bool strict_types = false;
130+
bool truncate_long_and_double = false;
129131
Device device;
130132
EngineCapability capability = EngineCapability::kDEFAULT;
131133
int64_t num_min_timing_iters = 2;

py/trtorch/csrc/trtorch_py.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,8 @@ PYBIND11_MODULE(_C, m) {
246246
.def_readwrite("num_min_timing_iters", &CompileSpec::num_min_timing_iters)
247247
.def_readwrite("num_avg_timing_iters", &CompileSpec::num_avg_timing_iters)
248248
.def_readwrite("workspace_size", &CompileSpec::workspace_size)
249-
.def_readwrite("max_batch_size", &CompileSpec::max_batch_size);
249+
.def_readwrite("max_batch_size", &CompileSpec::max_batch_size)
250+
.def_readwrite("truncate_long_and_double", &CompileSpec::truncate_long_and_double);
250251

251252
py::class_<Device>(m, "Device")
252253
.def(py::init<>())

0 commit comments

Comments
 (0)