Skip to content

Commit 482ef2c

Browse files
committed
feat(//py): Working portable package
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent a71bca9 commit 482ef2c

File tree

6 files changed

+27
-28
lines changed

6 files changed

+27
-28
lines changed

py/setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ def run(self):
121121
extra_link_args=[
122122
"-D_GLIBCXX_USE_CXX11_ABI=0"
123123
"-Wl,--no-as-needed",
124-
"-ltrtorch"
124+
"-ltrtorch",
125+
"-Wl,-rpath,$ORIGIN/lib"
125126
],
126127
undef_macros=[ "NDEBUG" ]
127128
)

py/trtorch/__init__.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,6 @@
77
import ctypes
88
import torch
99

10-
def _load_trtorch_lib():
11-
lib_name = 'libtrtorch.so'
12-
here = os.path.abspath(__file__)
13-
lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name)
14-
ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
15-
16-
_load_trtorch_lib()
17-
1810
from trtorch._version import __version__
1911
from trtorch._compiler import *
2012
from trtorch._types import *

py/trtorch/_compiler.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
from typing import List, Dict, Any
22
import torch
3+
from torch import nn
4+
35
import trtorch._C
46
from trtorch._extra_info import _parse_extra_info
57
from trtorch._version import __version__
8+
from types import FunctionType
9+
610

711
def compile(module: torch.jit.ScriptModule, extra_info: Any) -> torch.jit.ScriptModule:
812
"""Compile a TorchScript module for NVIDIA GPUs using TensorRT
@@ -50,7 +54,11 @@ def compile(module: torch.jit.ScriptModule, extra_info: Any) -> torch.jit.Script
5054
Returns:
5155
torch.jit.ScriptModule: Compiled TorchScript Module, when run it will execute via TensorRT
5256
"""
53-
compiled_cpp_mod = trtorch._C._compile_graph(module._c, _parse_extra_info(extra_info))
57+
58+
if isinstance(module, torch.jit.ScriptFunction):
59+
raise TypeError("torch.jit.ScriptFunction currently is not directly supported, wrap the function in a module to compile")
60+
61+
compiled_cpp_mod = trtorch._C.compile_graph(module._c, _parse_extra_info(extra_info))
5462
compiled_module = torch.jit._recursive.wrap_cpp_module(compiled_cpp_mod)
5563
return compiled_module
5664

@@ -98,7 +106,10 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
98106
Returns:
99107
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
100108
"""
101-
return trtorch._C._convert_graph_to_trt_engine(module._c, method_name, _parse_extra_info(extra_info))
109+
if isinstance(module, torch.jit.ScriptFunction):
110+
raise TypeError("torch.jit.ScriptFunctions currently are not directly supported, wrap the function in a module to compile")
111+
112+
return trtorch._C.convert_graph_to_trt_engine(module._c, method_name, _parse_extra_info(extra_info))
102113

103114
def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) -> bool:
104115
"""Checks to see if a method is fully supported by TRTorch
@@ -114,7 +125,7 @@ def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) ->
114125
Returns:
115126
bool: True if supported Method
116127
"""
117-
return trtorch._C._check_method_op_support(module._c, method_name)
128+
return trtorch._C.check_method_op_support(module._c, method_name)
118129

119130
def dump_build_info():
120131
"""Prints build information about the TRTorch distribution to stdout
@@ -127,7 +138,7 @@ def get_build_info() -> str:
127138
Returns:
128139
str: String containing the build information for TRTorch distribution
129140
"""
130-
build_info = trtorch._C._get_build_info()
141+
build_info = trtorch._C.get_build_info()
131142
build_info = "TRTorch Version: " + str(__version__) + '\n' + build_info
132143
return build_info
133144

py/trtorch/_extra_info.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,12 @@ def _parse_device_type(device: Any) -> _types.DeviceType:
8484
else:
8585
raise TypeError("Device specification must be of type torch.device or trtorch.DeviceType, but got: " + str(type(device)))
8686

87-
def _parse_extra_info(extra_info: Dict[str, Any]) -> trtorch._C._ExtraInfo:
88-
info = trtorch._C._ExtraInfo()
89-
if "input_shapes" not in extra_info and not isinstance(extra_info["input_shapes"], list):
87+
def _parse_extra_info(extra_info: Dict[str, Any]) -> trtorch._C.ExtraInfo:
88+
info = trtorch._C.ExtraInfo()
89+
if "input_shapes" not in extra_info:
9090
raise KeyError("Input shapes for inputs are required as a List, provided as either a static sizes or a range of three sizes (min, opt, max) as Dict")
9191

9292
info.input_ranges = _parse_input_ranges(extra_info["input_shapes"])
93-
print(info.input_ranges)
9493

9594
if "op_precision" in extra_info:
9695
info.op_precision = _parse_op_precision(extra_info["op_precision"])

py/trtorch/csrc/trtorch_py.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@ struct InputRange {
1818
std::vector<int64_t> max;
1919

2020
core::conversion::InputRange toInternalInputRange() {
21-
for (auto o : opt) {
22-
std::cout << o << std::endl;
23-
}
2421
return core::conversion::InputRange(min, opt, max);
2522
}
2623
};
@@ -79,7 +76,6 @@ nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) {
7976
struct ExtraInfo {
8077

8178
core::ExtraInfo toInternalExtraInfo() {
82-
std::cout << "HELLO" << input_ranges.size() << std::endl;
8379
for (auto i : input_ranges) {
8480
internal_input_ranges.push_back(i.toInternalInputRange());
8581
}
@@ -193,7 +189,7 @@ PYBIND11_MODULE(_C, m) {
193189
.value("safe_dla", EngineCapability::kSAFE_DLA, "Use safety DLA kernels only")
194190
.value("default", EngineCapability::kDEFAULT, "Use default behavior");
195191

196-
py::class_<ExtraInfo>(m, "_ExtraInfo")
192+
py::class_<ExtraInfo>(m, "ExtraInfo")
197193
.def(py::init<>())
198194
.def_readwrite("input_ranges", &ExtraInfo::input_ranges)
199195
.def_readwrite("op_precision", &ExtraInfo::op_precision)
@@ -209,10 +205,10 @@ PYBIND11_MODULE(_C, m) {
209205
.def_readwrite("max_batch_size", &ExtraInfo::max_batch_size);
210206

211207
m.doc() = "TRTorch Internal C Bindings: Ahead of Time compilation for PyTorch JIT. A tool to convert PyTorch JIT to TensorRT";
212-
m.def("_compile_graph", &trtorch::pyapi::CompileGraph, "Ingest a PyTorch JIT module and convert supported subgraphs to TensorRT engines, returns a JIT module with the engines embedded");
213-
m.def("_convert_graph_to_trt_engine", &trtorch::pyapi::ConvertGraphToTRTEngine, "Given a PyTorch JIT Module, convert forward into a TensorRT engine and return a serialized engine");
214-
m.def("_check_method_op_support", &trtorch::pyapi::CheckMethodOperatorSupport, "Takes a module and a method name and checks if the method graph contains purely convertable operators");
215-
m.def("_get_build_info", &get_build_info, "Returns build info about the compiler as a string");
208+
m.def("compile_graph", &trtorch::pyapi::CompileGraph, "Ingest a PyTorch JIT module and convert supported subgraphs to TensorRT engines, returns a JIT module with the engines embedded");
209+
m.def("convert_graph_to_trt_engine", &trtorch::pyapi::ConvertGraphToTRTEngine, "Given a PyTorch JIT Module, convert forward into a TensorRT engine and return a serialized engine");
210+
m.def("check_method_op_support", &trtorch::pyapi::CheckMethodOperatorSupport, "Takes a module and a method name and checks if the method graph contains purely convertable operators");
211+
m.def("get_build_info", &get_build_info, "Returns build info about the compiler as a string");
216212

217213
m.def("_get_logging_prefix", &logging::get_logging_prefix, "Get the current prefix for the logging output");
218214
m.def("_set_logging_prefix", &logging::set_logging_prefix, "Set the logging prefix for logging output");

py/trtorch/logging.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def set_logging_prefix(prefix: str):
4040
Args:
4141
prefix (str): Prefix to use for logging messages
4242
"""
43-
_set_logging_prefix(str)
43+
_set_logging_prefix(prefix)
4444

4545
def get_reportable_log_level() -> Level:
4646
"""Get the level required for a message to be printed in the log
@@ -84,4 +84,4 @@ def log(level: Level, msg: str):
8484
level (trtorch.logging.Level): Severity of the message
8585
msg (str): Actual message text
8686
"""
87-
_log(level, msg)
87+
_log(Level._to_internal_level(level), msg)

0 commit comments

Comments
 (0)