Skip to content

Commit 72bc1f7

Browse files
committed
feat(//py): API now produces valid engines that are consumable by
TensorRT Python API Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent c1de126 commit 72bc1f7

File tree

3 files changed

+31
-21
lines changed

3 files changed

+31
-21
lines changed

py/setup.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import sys
3+
import glob
34
import setuptools
45
from setuptools import setup, Extension, find_packages
56
from setuptools.command.build_ext import build_ext
@@ -8,7 +9,7 @@
89
from distutils.cmd import Command
910

1011
from torch.utils import cpp_extension
11-
from shutil import copyfile
12+
from shutil import copyfile, rmtree
1213

1314
dir_path = os.path.dirname(os.path.realpath(__file__))
1415

@@ -60,7 +61,7 @@ def run(self):
6061

6162
class CleanCommand(Command):
6263
"""Custom clean command to tidy up the project root."""
63-
PY_CLEAN_FILES = ['./build', './dist', './trtorch/__pycache__', './*.pyc', './*.tgz', './*.egg-info']
64+
PY_CLEAN_FILES = ['./build', './dist', './trtorch/__pycache__', './trtorch/lib', './*.pyc', './*.tgz', './*.egg-info']
6465
description = "Command to tidy up the project root"
6566
user_options = []
6667

@@ -75,11 +76,11 @@ def run(self):
7576
# Make paths absolute and relative to this path
7677
abs_paths = glob.glob(os.path.normpath(os.path.join(dir_path, path_spec)))
7778
for path in [str(p) for p in abs_paths]:
78-
if not path.startswith(root_dir):
79+
if not path.startswith(dir_path):
7980
# Die if path in CLEAN_FILES is absolute + outside this directory
80-
raise ValueError("%s is not a path inside %s" % (path, root_dir))
81+
raise ValueError("%s is not a path inside %s" % (path, dir_path))
8182
print('Removing %s' % os.path.relpath(path))
82-
shutil.rmtree(path)
83+
rmtree(path)
8384

8485
ext_modules = [
8586
cpp_extension.CUDAExtension('trtorch._C',

py/trtorch/compiler.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
1515
else:
1616
raise TypeError("Input sizes for inputs are required to be a List, tuple or torch.Size or a Dict of three sizes (min, opt, max), found type: " + str(type(input_size)))
1717

18-
def _parse_input_sizes(input_sizes: List) -> List:
18+
def _parse_input_ranges(input_sizes: List) -> List:
1919

2020
if any (not isinstance(i, dict) and not _supported_input_size_type(i) for i in input_sizes):
2121
raise KeyError("An input size must either be a static size or a range of three sizes (min, opt, max) as Dict")
@@ -28,16 +28,14 @@ def _parse_input_sizes(input_sizes: List) -> List:
2828
in_range.min = i["min"]
2929
in_range.opt = i["opt"]
3030
in_range.max = i["max"]
31-
32-
parsed_input_sizes.append(in_range.to_internal_input_range())
31+
parsed_input_sizes.append(in_range)
3332

3433
elif "opt" in i:
3534
in_range = trtorch._C.InputRange()
3635
in_range.min = i["opt"]
3736
in_range.opt = i["opt"]
3837
in_range.max = i["opt"]
39-
40-
parsed_input_sizes.append(in_range.to_internal_input_range())
38+
parsed_input_sizes.append(in_range)
4139

4240
else:
4341
raise KeyError("An input size must either be a static size or a range of three sizes (min, opt, max) as Dict")
@@ -47,8 +45,14 @@ def _parse_input_sizes(input_sizes: List) -> List:
4745
in_range.min = i
4846
in_range.opt = i
4947
in_range.max = i
48+
parsed_input_sizes.append(in_range)
5049

51-
parsed_input_sizes.append(in_range.to_internal_input_range())
50+
elif isinstance(i, tuple):
51+
in_range = trtorch._C.InputRange()
52+
in_range.min = list(i)
53+
in_range.opt = list(i)
54+
in_range.max = list(i)
55+
parsed_input_sizes.append(in_range)
5256

5357
return parsed_input_sizes
5458

@@ -87,7 +91,8 @@ def _parse_extra_info(extra_info: Dict[str, Any]) -> trtorch._C._ExtraInfo:
8791
if "input_shapes" not in extra_info and not isinstance(extra_info["input_shapes"], list):
8892
raise KeyError("Input shapes for inputs are required as a List, provided as either a static sizes or a range of three sizes (min, opt, max) as Dict")
8993

90-
info.input_ranges = _parse_input_sizes(extra_info["input_shapes"])
94+
info.input_ranges = _parse_input_ranges(extra_info["input_shapes"])
95+
print(info.input_ranges)
9196

9297
if "op_precision" in extra_info:
9398
info.op_precision = _parse_op_precision(extra_info["op_precision"])

py/trtorch/csrc/trtorch_py.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ struct InputRange {
1818
std::vector<int64_t> max;
1919

2020
core::conversion::InputRange toInternalInputRange() {
21+
for (auto o : opt) {
22+
std::cout << o << std::endl;
23+
}
2124
return core::conversion::InputRange(min, opt, max);
2225
}
2326
};
@@ -76,7 +79,11 @@ nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) {
7679
struct ExtraInfo {
7780

7881
core::ExtraInfo toInternalExtraInfo() {
79-
auto info = core::ExtraInfo(input_ranges);
82+
std::cout << "HELLO" << input_ranges.size() << std::endl;
83+
for (auto i : input_ranges) {
84+
internal_input_ranges.push_back(i.toInternalInputRange());
85+
}
86+
auto info = core::ExtraInfo(internal_input_ranges);
8087
info.convert_info.engine_settings.op_precision = toTRTDataType(op_precision);
8188
info.convert_info.engine_settings.refit = refit;
8289
info.convert_info.engine_settings.debug = debug;
@@ -91,7 +98,8 @@ struct ExtraInfo {
9198
return info;
9299
}
93100

94-
std::vector<core::conversion::InputRange> input_ranges;
101+
std::vector<InputRange> input_ranges;
102+
std::vector<core::conversion::InputRange> internal_input_ranges;
95103
DataType op_precision = DataType::kFloat;
96104
bool refit = false;
97105
bool debug = false;
@@ -112,10 +120,10 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, ExtraInfo& info)
112120
return trt_mod;
113121
}
114122

115-
std::string ConvertGraphToTRTEngine(const torch::jit::Module& mod, const std::string& method_name, ExtraInfo& info) {
123+
py::bytes ConvertGraphToTRTEngine(const torch::jit::Module& mod, const std::string& method_name, ExtraInfo& info) {
116124
py::gil_scoped_acquire gil;
117125
auto trt_engine = core::ConvertGraphToTRTEngine(mod, method_name, info.toInternalExtraInfo());
118-
return trt_engine;
126+
return py::bytes(trt_engine);
119127
}
120128

121129
bool CheckMethodOperatorSupport(const torch::jit::Module& module, const std::string& method_name) {
@@ -136,11 +144,7 @@ PYBIND11_MODULE(_C, m) {
136144
.def(py::init<>())
137145
.def_readwrite("min", &InputRange::min)
138146
.def_readwrite("opt", &InputRange::opt)
139-
.def_readwrite("max", &InputRange::max)
140-
.def("_to_internal_input_range", &InputRange::toInternalInputRange);
141-
142-
//py::class_<core::conversion::InputRange>(m, "_InternalInputRange")
143-
// .def(py::init<>());
147+
.def_readwrite("max", &InputRange::max);
144148

145149
py::enum_<DataType>(m, "dtype")
146150
.value("float", DataType::kFloat)

0 commit comments

Comments
 (0)