Skip to content

Commit bf1b2d8

Browse files
committed
feat(//py): Gate partial compilation from to_backend API
We can't run partial compilation on modules from the to_backend API because we are expected to simply return a handle to a TRT engine rather than a full graph, so we cannot do graph stitching. Now an exception is thrown if someone tries to use fallback with to_backend, directing them towards trtorch.compile.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 0a3258d commit bf1b2d8

File tree

5 files changed

+107
-72
lines changed

5 files changed

+107
-72
lines changed

py/trtorch/_compile_spec.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -259,37 +259,40 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
259259
backend_spec = torch.classes.tensorrt.CompileSpec()
260260

261261
for i in parsed_spec.input_ranges:
262-
ir = torch.classes.tensorrt.InputRange()
263-
ir.set_min(i.min)
264-
ir.set_opt(i.opt)
265-
ir.set_max(i.max)
266-
backend_spec.append_input_range(ir)
267-
268-
d = torch.classes.tensorrt.Device()
269-
d.set_device_type(int(parsed_spec.device.device_type))
270-
d.set_gpu_id(parsed_spec.device.gpu_id)
271-
d.set_dla_core(parsed_spec.device.dla_core)
272-
d.set_allow_gpu_fallback(parsed_spec.device.allow_gpu_fallback)
273-
274-
torch_fallback = torch.classes.tensorrt.TorchFallback()
275-
torch_fallback.set_enabled(parsed_spec.torch_fallback.enabled)
276-
torch_fallback.set_min_block_size(parsed_spec.torch_fallback.min_block_size)
277-
torch_fallback.set_forced_fallback_operators(parsed_spec.torch_fallback.forced_fallback_operators)
278-
279-
backend_spec.set_device(d)
280-
backend_spec.set_torch_fallback(fallback)
281-
backend_spec.set_op_precision(int(parsed_spec.op_precision))
282-
backend_spec.set_disable_tf32(parsed_spec.disable_tf32)
283-
backend_spec.set_refit(parsed_spec.refit)
284-
backend_spec.set_debug(parsed_spec.debug)
285-
backend_spec.set_refit(parsed_spec.refit)
286-
backend_spec.set_strict_types(parsed_spec.strict_types)
287-
backend_spec.set_capability(int(parsed_spec.capability))
288-
backend_spec.set_num_min_timing_iters(parsed_spec.num_min_timing_iters)
289-
backend_spec.set_num_avg_timing_iters(parsed_spec.num_avg_timing_iters)
290-
backend_spec.set_workspace_size(parsed_spec.workspace_size)
291-
backend_spec.set_max_batch_size(parsed_spec.max_batch_size)
292-
backend_spec.set_truncate_long_and_double(parsed_spec.truncate_long_and_double)
262+
ir = torch.classes.tensorrt._InputRange()
263+
ir._set_min(i.min)
264+
ir._set_opt(i.opt)
265+
ir._set_max(i.max)
266+
backend_spec._append_input_range(ir)
267+
268+
d = torch.classes.tensorrt._Device()
269+
d._set_device_type(int(parsed_spec.device.device_type))
270+
d._set_gpu_id(parsed_spec.device.gpu_id)
271+
d._set_dla_core(parsed_spec.device.dla_core)
272+
d._set_allow_gpu_fallback(parsed_spec.device.allow_gpu_fallback)
273+
274+
if parsed_spec.torch_fallback.enabled:
275+
raise RuntimeError("Partial module compilation is not currently supported via the PyTorch to_backend API integration. If you need partial compilation, use trtorch.compile")
276+
277+
torch_fallback = torch.classes.tensorrt._TorchFallback()
278+
torch_fallback._set_enabled(parsed_spec.torch_fallback.enabled)
279+
torch_fallback._set_min_block_size(parsed_spec.torch_fallback.min_block_size)
280+
torch_fallback._set_forced_fallback_operators(parsed_spec.torch_fallback.forced_fallback_operators)
281+
282+
backend_spec._set_device(d)
283+
backend_spec._set_torch_fallback(torch_fallback)
284+
backend_spec._set_op_precision(int(parsed_spec.op_precision))
285+
backend_spec._set_disable_tf32(parsed_spec.disable_tf32)
286+
backend_spec._set_refit(parsed_spec.refit)
287+
backend_spec._set_debug(parsed_spec.debug)
288+
backend_spec._set_refit(parsed_spec.refit)
289+
backend_spec._set_strict_types(parsed_spec.strict_types)
290+
backend_spec._set_capability(int(parsed_spec.capability))
291+
backend_spec._set_num_min_timing_iters(parsed_spec.num_min_timing_iters)
292+
backend_spec._set_num_avg_timing_iters(parsed_spec.num_avg_timing_iters)
293+
backend_spec._set_workspace_size(parsed_spec.workspace_size)
294+
backend_spec._set_max_batch_size(parsed_spec.max_batch_size)
295+
backend_spec._set_truncate_long_and_double(parsed_spec.truncate_long_and_double)
293296
backend_spec._set_ptq_calibrator(parsed_spec._get_calibrator_handle())
294297

295298
return backend_spec

py/trtorch/csrc/register_tensorrt_classes.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,44 @@ namespace backend {
55
namespace {
66

77
#define ADD_FIELD_GET_SET_REGISTRATION(registry, class_name, field_name) \
8-
(registry).def("set_" #field_name, &class_name::set_##field_name); \
9-
(registry).def("get_" #field_name, &class_name::get_##field_name);
8+
(registry).def("_set_" #field_name, &class_name::set_##field_name); \
9+
(registry).def("_get_" #field_name, &class_name::get_##field_name);
1010

1111
void RegisterTRTCompileSpec() {
1212
static auto TRTORCH_UNUSED TRTInputRangeTSRegistration =
13-
torch::class_<trtorch::pyapi::InputRange>("tensorrt", "InputRange").def(torch::init<>());
13+
torch::class_<trtorch::pyapi::InputRange>("tensorrt", "_InputRange")
14+
.def(torch::init<>())
15+
.def("__str__", &trtorch::pyapi::InputRange::to_str);
1416

1517
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, min);
1618
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, opt);
1719
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, max);
1820

1921
static auto TRTORCH_UNUSED TRTDeviceTSRegistration =
20-
torch::class_<trtorch::pyapi::Device>("tensorrt", "Device").def(torch::init<>());
22+
torch::class_<trtorch::pyapi::Device>("tensorrt", "_Device")
23+
.def(torch::init<>())
24+
.def("__str__", &trtorch::pyapi::Device::to_str);
2125

2226
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, device_type);
2327
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, gpu_id);
2428
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, dla_core);
2529
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, allow_gpu_fallback);
2630

2731
static auto TRTORCH_UNUSED TRTFallbackTSRegistration =
28-
torch::class_<trtorch::pyapi::TorchFallback>("tensorrt", "Fallback").def(torch::init<>());
32+
torch::class_<trtorch::pyapi::TorchFallback>("tensorrt", "_TorchFallback")
33+
.def(torch::init<>())
34+
.def("__str__", &trtorch::pyapi::TorchFallback::to_str);
35+
2936
ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, enabled);
3037
ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, min_block_size);
3138
ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, forced_fallback_operators);
3239

3340
static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration =
3441
torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
3542
.def(torch::init<>())
36-
.def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
37-
.def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
38-
.def("set_torch_fallback", &trtorch::pyapi::CompileSpec::setTorchFallbackIntrusive)
43+
.def("_append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
44+
.def("_set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
45+
.def("_set_torch_fallback", &trtorch::pyapi::CompileSpec::setTorchFallbackIntrusive)
3946
.def("_set_ptq_calibrator", &trtorch::pyapi::CompileSpec::setPTQCalibratorViaHandle)
4047
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);
4148

py/trtorch/csrc/tensorrt_classes.cpp

Lines changed: 47 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
namespace trtorch {
55
namespace pyapi {
66

7-
std::string to_str(InputRange& value) {
7+
std::string InputRange::to_str() {
88
auto vec_to_str = [](std::vector<int64_t> shape) -> std::string {
99
std::stringstream ss;
1010
ss << '[';
@@ -17,9 +17,9 @@ std::string to_str(InputRange& value) {
1717

1818
std::stringstream ss;
1919
ss << " {" << std::endl;
20-
ss << " min: " << vec_to_str(value.min) << ',' << std::endl;
21-
ss << " opt: " << vec_to_str(value.opt) << ',' << std::endl;
22-
ss << " max: " << vec_to_str(value.max) << ',' << std::endl;
20+
ss << " min: " << vec_to_str(min) << ',' << std::endl;
21+
ss << " opt: " << vec_to_str(opt) << ',' << std::endl;
22+
ss << " max: " << vec_to_str(max) << ',' << std::endl;
2323
ss << " }" << std::endl;
2424
return ss.str();
2525
}
@@ -68,6 +68,18 @@ nvinfer1::DeviceType toTRTDeviceType(DeviceType value) {
6868
}
6969
}
7070

71+
std::string Device::to_str() {
72+
std::stringstream ss;
73+
std::string fallback = allow_gpu_fallback ? "True" : "False";
74+
ss << " {" << std::endl;
75+
ss << " \"device_type\": " << pyapi::to_str(device_type) << std::endl;
76+
ss << " \"allow_gpu_fallback\": " << fallback << std::endl;
77+
ss << " \"gpu_id\": " << gpu_id << std::endl;
78+
ss << " \"dla_core\": " << dla_core << std::endl;
79+
ss << " }" << std::endl;
80+
return ss.str();
81+
}
82+
7183
std::string to_str(EngineCapability value) {
7284
switch (value) {
7385
case EngineCapability::kSAFE_GPU:
@@ -92,6 +104,21 @@ nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) {
92104
}
93105
}
94106

107+
std::string TorchFallback::to_str() {
108+
std::stringstream ss;
109+
std::string e = enabled ? "True" : "False";
110+
ss << " {" << std::endl;
111+
ss << " \"enabled\": " << e << std::endl;
112+
ss << " \"min_block_size\": " << min_block_size << std::endl;
113+
ss << " \"forced_fallback_operators\": [" << std::endl;
114+
for (auto i : forced_fallback_operators) {
115+
ss << " " << i << ',' << std::endl;
116+
}
117+
ss << " ]" << std::endl;
118+
ss << " }" << std::endl;
119+
return ss.str();
120+
}
121+
95122
core::CompileSpec CompileSpec::toInternalCompileSpec() {
96123
std::vector<core::ir::InputRange> internal_input_ranges;
97124
for (auto i : input_ranges) {
@@ -128,36 +155,25 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
128155
std::string CompileSpec::stringify() {
129156
std::stringstream ss;
130157
ss << "TensorRT Compile Spec: {" << std::endl;
131-
ss << " \"Input Shapes\": [" << std::endl;
158+
ss << " \"Input Shapes\": [" << std::endl;
132159
for (auto i : input_ranges) {
133-
ss << to_str(i);
160+
ss << i.to_str();
134161
}
135162
std::string enabled = torch_fallback.enabled ? "True" : "False";
136-
ss << " ]" << std::endl;
137-
ss << " \"Op Precision\": " << to_str(op_precision) << std::endl;
138-
ss << " \"TF32 Disabled\": " << disable_tf32 << std::endl;
139-
ss << " \"Refit\": " << refit << std::endl;
140-
ss << " \"Debug\": " << debug << std::endl;
141-
ss << " \"Strict Types\": " << strict_types << std::endl;
142-
ss << " \"Device Type: " << to_str(device.device_type) << std::endl;
143-
ss << " \"GPU ID: " << device.gpu_id << std::endl;
144-
ss << " \"DLA Core: " << device.dla_core << std::endl;
145-
ss << " \"Allow GPU Fallback\": " << device.allow_gpu_fallback << std::endl;
146-
ss << " \"Engine Capability\": " << to_str(capability) << std::endl;
147-
ss << " \"Num Min Timing Iters\": " << num_min_timing_iters << std::endl;
148-
ss << " \"Num Avg Timing Iters\": " << num_avg_timing_iters << std::endl;
149-
ss << " \"Workspace Size\": " << workspace_size << std::endl;
150-
ss << " \"Max Batch Size\": " << max_batch_size << std::endl;
151-
ss << " \"Truncate long and double\": " << truncate_long_and_double << std::endl;
152-
ss << " \"Torch Fallback: {" << std::endl;
153-
ss << " \"enabled\": " << enabled << std::endl;
154-
ss << " \"min_block_size\": " << torch_fallback.min_block_size << std::endl;
155-
ss << " \"forced_fallback_operators\": [" << std::endl;
156-
for (auto i : torch_fallback.forced_fallback_operators) {
157-
ss << " " << i << ',' << std::endl;
158-
}
159-
ss << " ]" << std::endl;
160-
ss << " }" << std::endl;
163+
ss << " ]" << std::endl;
164+
ss << " \"Op Precision\": " << to_str(op_precision) << std::endl;
165+
ss << " \"TF32 Disabled\": " << disable_tf32 << std::endl;
166+
ss << " \"Refit\": " << refit << std::endl;
167+
ss << " \"Debug\": " << debug << std::endl;
168+
ss << " \"Strict Types\": " << strict_types << std::endl;
169+
ss << " \"Device\": " << device.to_str() << std::endl;
170+
ss << " \"Engine Capability\": " << to_str(capability) << std::endl;
171+
ss << " \"Num Min Timing Iters\": " << num_min_timing_iters << std::endl;
172+
ss << " \"Num Avg Timing Iters\": " << num_avg_timing_iters << std::endl;
173+
ss << " \"Workspace Size\": " << workspace_size << std::endl;
174+
ss << " \"Max Batch Size\": " << max_batch_size << std::endl;
175+
ss << " \"Truncate long and double\": " << truncate_long_and_double << std::endl;
176+
ss << " \"Torch Fallback\": " << torch_fallback.to_str();
161177
ss << "}";
162178
return ss.str();
163179
}

py/trtorch/csrc/tensorrt_classes.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ struct InputRange : torch::CustomClassHolder {
3939
ADD_FIELD_GET_SET(min, std::vector<int64_t>);
4040
ADD_FIELD_GET_SET(opt, std::vector<int64_t>);
4141
ADD_FIELD_GET_SET(max, std::vector<int64_t>);
42-
};
4342

44-
std::string to_str(InputRange& value);
43+
std::string to_str();
44+
};
4545

4646
enum class DataType : int8_t {
4747
kFloat,
@@ -73,6 +73,8 @@ struct Device : torch::CustomClassHolder {
7373
ADD_FIELD_GET_SET(gpu_id, int64_t);
7474
ADD_FIELD_GET_SET(dla_core, int64_t);
7575
ADD_FIELD_GET_SET(allow_gpu_fallback, bool);
76+
77+
std::string to_str();
7678
};
7779

7880
std::string to_str(DeviceType value);
@@ -87,8 +89,11 @@ struct TorchFallback : torch::CustomClassHolder {
8789
ADD_FIELD_GET_SET(enabled, bool);
8890
ADD_FIELD_GET_SET(min_block_size, int64_t);
8991
ADD_FIELD_GET_SET(forced_fallback_operators, std::vector<std::string>);
92+
93+
std::string to_str();
9094
};
9195

96+
9297
enum class EngineCapability : int8_t {
9398
kDEFAULT,
9499
kSAFE_GPU,

py/trtorch/csrc/trtorch_py.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ void log(core::util::logging::LogLevel lvl, const std::string& msg) {
165165
PYBIND11_MODULE(_C, m) {
166166
py::class_<InputRange>(m, "InputRange")
167167
.def(py::init<>())
168+
.def("__str__", &trtorch::pyapi::InputRange::to_str)
168169
.def_readwrite("min", &InputRange::min)
169170
.def_readwrite("opt", &InputRange::opt)
170171
.def_readwrite("max", &InputRange::max);
@@ -237,6 +238,7 @@ PYBIND11_MODULE(_C, m) {
237238

238239
py::class_<CompileSpec>(m, "CompileSpec")
239240
.def(py::init<>())
241+
.def("__str__", &trtorch::pyapi::CompileSpec::stringify)
240242
.def("_get_calibrator_handle", &CompileSpec::getPTQCalibratorHandle, "[Internal] gets a handle from a calibrator")
241243
.def_readwrite("input_ranges", &CompileSpec::input_ranges)
242244
.def_readwrite("op_precision", &CompileSpec::op_precision)
@@ -256,13 +258,15 @@ PYBIND11_MODULE(_C, m) {
256258

257259
py::class_<Device>(m, "Device")
258260
.def(py::init<>())
261+
.def("__str__", &trtorch::pyapi::Device::to_str)
259262
.def_readwrite("device_type", &Device::device_type)
260263
.def_readwrite("gpu_id", &Device::gpu_id)
261264
.def_readwrite("dla_core", &Device::dla_core)
262265
.def_readwrite("allow_gpu_fallback", &Device::allow_gpu_fallback);
263266

264267
py::class_<TorchFallback>(m, "TorchFallback")
265268
.def(py::init<>())
269+
.def("__str__", &trtorch::pyapi::TorchFallback::to_str)
266270
.def_readwrite("enabled", &TorchFallback::enabled)
267271
.def_readwrite("min_block_size", &TorchFallback::min_block_size)
268272
.def_readwrite("forced_fallback_operators", &TorchFallback::forced_fallback_operators);

0 commit comments

Comments
 (0)