diff --git a/BUILD b/BUILD
index 48f2ae5338..02829531b2 100644
--- a/BUILD
+++ b/BUILD
@@ -8,6 +8,8 @@ pkg_tar(
"//core/conversion:include",
"//core/conversion/conversionctx:include",
"//core/conversion/converters:include",
+ "//core/conversion/var:include",
+ "//core/conversion/tensorcontainer:include",
"//core/conversion/evaluators:include",
"//core/execution:include",
"//core/lowering:include",
@@ -35,6 +37,15 @@ pkg_tar(
)
+pkg_tar(
+ name = "bin",
+ package_dir = "bin/",
+ srcs = [
+ "//cpp/trtorchc:trtorchc",
+ ],
+ mode = "0755",
+)
+
pkg_tar(
@@ -46,6 +57,7 @@ pkg_tar(
],
deps = [
":lib",
+ ":bin",
":include",
":include_core",
],
diff --git a/README.md b/README.md
index befe86e8fe..60cfe55e94 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,8 @@ compile_settings.op_precision = torch::kFloat;
auto trt_mod = trtorch::CompileGraph(ts_mod, compile_settings);
// Run like normal
auto results = trt_mod.forward({in_tensor});
+// Save module for later
+trt_mod.save("trt_torchscript_module.ts");
...
```
@@ -46,6 +48,7 @@ trt_ts_module = trtorch.compile(torch_script_module, compile_settings)
input_data = input_data.half()
result = trt_ts_module(input_data)
+torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")
```
> Notes on running in lower precisions:
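For context, a minimal end-to-end sketch of the C++ round trip the new `save` line enables (file names and the fixed input size are illustrative; `ExtraInfo` construction follows the snippet above):

```cpp
#include "torch/script.h"
#include "trtorch/trtorch.h"

int main() {
  // Compile as in the snippet above
  auto ts_mod = torch::jit::load("model.ts");
  ts_mod.to(at::kCUDA);
  ts_mod.eval();

  auto compile_settings = trtorch::ExtraInfo({{1, 3, 224, 224}});
  compile_settings.op_precision = torch::kFloat;
  auto trt_mod = trtorch::CompileGraph(ts_mod, compile_settings);

  // Save, then reload with the stock TorchScript loader; only libtrtorch.so
  // needs to be linked so the runtime op and Engine class are registered
  trt_mod.save("trt_torchscript_module.ts");
  auto reloaded = torch::jit::load("trt_torchscript_module.ts");
  auto out = reloaded.forward({torch::randn({1, 3, 224, 224}, at::kCUDA)});
  return 0;
}
```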
diff --git a/core/compiler.cpp b/core/compiler.cpp
index 2f94ba8ead..be0dc895d8 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -6,7 +6,9 @@
#include "NvInfer.h"
#include "ATen/core/function_schema.h"
+#include "ATen/core/jit_type.h"
+#include "torch/custom_class.h"
#include "torch/csrc/jit/frontend/function_schema_parser.h"
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/passes/pass_manager.h"
@@ -40,32 +42,70 @@ c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::str
void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit::Graph>& g, std::string& serialized_engine) {
- execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine);
- auto num_io = execution::GetEngineIO(uid);
-
- auto self = g->addInput("self.1");
+ auto engine = execution::TRTEngine(mod._ivalue()->name(), serialized_engine);
+ // Get required metadata about the engine out
+ auto num_io = engine.num_io;
+ auto name = engine.name;
+
+ // Add the engine as an attribute of the module; this lets the engine be serialized and deserialized
+ auto engine_ptr = c10::make_intrusive<execution::TRTEngine>(engine);
+ mod.register_attribute(
+ name,
+ c10::getCustomClassType<c10::intrusive_ptr<execution::TRTEngine>>(),
+ c10::IValue(std::move(engine_ptr)),
+ false
+ );
+
+ // Add the module as an input into the graph
+ auto self = g->addInput("self_1");
self->setType(mod.type());
- auto id_val = g->insertConstant(uid);
+ // Start by retrieving the engine from the module attribute list
+ auto engine_node = g->createGetAttr(self, name);
+ g->block()->appendNode(engine_node);
+ // Add inputs to the graph corresponding to the number of input tensors expected by the engine
+ // Also store those inputs in a vector so that they can be coalesced into a single list at runtime
std::vector<torch::jit::Value*> engine_inputs;
- engine_inputs.push_back(id_val);
-
for (uint64_t i = 0; i < num_io.first; i++) {
- auto in_val = g->addInput("");
+ auto in_val = g->addInput(std::string("input_") + std::to_string(i));
in_val->setType(c10::TensorType::get());
engine_inputs.push_back(in_val);
}
- auto engine_node = g->create(c10::Symbol::fromQualString("trt::execute_engine"), torch::jit::ArrayRef<torch::jit::Value*>(engine_inputs), num_io.second);
- g->block()->appendNode(engine_node);
-
- if (engine_node->outputs().size() > 1) {
- auto return_tuple_node = g->createTuple(engine_node->outputs());
+ // Create a node that will merge all of the input tensors into a single list argument to the trt::execute_engine op
+ // Creates: prim::ListConstruct(<input tensors>)
+ auto input_list_node = g->createList(c10::TensorType::get(), torch::jit::ArrayRef<torch::jit::Value*>(engine_inputs));
+ g->block()->appendNode(input_list_node);
+
+ // Make a list of inputs to the actual trt::execute_engine op
+ // Note: The list goes before the engine so that the engine, which carries all the metadata
+ // needed for execution, can be popped off the stack first
+ std::vector<torch::jit::Value*> execute_node_inputs;
+ execute_node_inputs.push_back(input_list_node->outputs()[0]);
+ execute_node_inputs.push_back(engine_node->outputs()[0]);
+
+ // Create the actual execution node trt::execute_engine using the assembled inputs
+ auto execute_node = g->create(c10::Symbol::fromQualString("trt::execute_engine"), torch::jit::ArrayRef<torch::jit::Value*>(execute_node_inputs), 1);
+ g->block()->appendNode(execute_node);
+ execute_node->outputs()[0]->setType(c10::ListType::ofTensors());
+
+ // Create a node to unpack the list into separate tensors; if there is only one tensor it is returned directly,
+ // otherwise the tensors are returned as a tuple.
+ // Creates: prim::ListUnpack(<engine output>)
+ auto unpack_node = g->createListUnpack(execute_node->outputs()[0], num_io.second);
+ g->block()->appendNode(unpack_node);
+
+ // If there are multiple output tensors from TensorRT we wrap them in a tuple to return
+ if (unpack_node->outputs().size() > 1) {
+ // Creates prim::TupleConstruct(<output tensors>) using outputs of the unpack node
+ auto return_tuple_node = g->createTuple(unpack_node->outputs());
g->block()->appendNode(return_tuple_node);
+ // Set the output as the produced tuple
g->registerOutput(return_tuple_node->outputs()[0]);
} else {
- g->registerOutput(engine_node->outputs()[0]);
+ // Set the output as the sole output tensor
+ g->registerOutput(unpack_node->outputs()[0]);
}
LOG_DEBUG(*g << "(AddEngineToGraph)\n");
diff --git a/core/conversion/InterfaceTypes.cpp b/core/conversion/InterfaceTypes.cpp
index ac90085583..3ec3d93178 100644
--- a/core/conversion/InterfaceTypes.cpp
+++ b/core/conversion/InterfaceTypes.cpp
@@ -34,7 +34,7 @@ InputRange::InputRange(std::vector<int64_t> d) {
min = util::toDims(d);
max = util::toDims(d);
input_shape = util::toDims(d);
-
+ input_is_dynamic = false;
}
@@ -67,6 +67,7 @@ InputRange::InputRange(std::vector<int64_t> min_shape, std::vector<int64_t> opt_
dim.insert(max_shape[i]);
if (dim.size() != 1) {
dyn_shape.push_back(-1);
+ input_is_dynamic = true;
} else {
dyn_shape.push_back(opt_shape[i]);
}
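The new flag falls out of the per-dimension comparison above: any dimension on which min/opt/max disagree becomes -1 and marks the whole input dynamic. A standalone sketch of the same rule (helper name is illustrative, not part of the API):

```cpp
#include <cstdint>
#include <set>
#include <utility>
#include <vector>

// e.g. min=(1,3,100,100), opt=(1,3,200,200), max=(1,3,300,300)
//      -> dims=(1,3,-1,-1), dynamic=true
std::pair<std::vector<int64_t>, bool> dyn_dims(const std::vector<int64_t>& min,
                                               const std::vector<int64_t>& opt,
                                               const std::vector<int64_t>& max) {
  std::vector<int64_t> dyn;
  bool dynamic = false;
  for (size_t i = 0; i < opt.size(); i++) {
    std::set<int64_t> dim{min[i], opt[i], max[i]};
    if (dim.size() != 1) {  // sizes disagree: dimension is dynamic
      dyn.push_back(-1);
      dynamic = true;
    } else {
      dyn.push_back(opt[i]);
    }
  }
  return {dyn, dynamic};
}
```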
diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp
index 911e58e039..fc4e75ca88 100644
--- a/core/conversion/conversion.cpp
+++ b/core/conversion/conversion.cpp
@@ -155,6 +155,10 @@ void AddInputs(ConversionCtx* ctx,
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kOPT, dims.opt);
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMAX, dims.max);
+ if (dims.input_is_dynamic) {
+ ctx->input_is_dynamic = true;
+ }
+
ctx->value_tensor_map[in] = trt_in;
}
diff --git a/core/conversion/conversion.h b/core/conversion/conversion.h
index 529d04f6b6..1c7a790025 100644
--- a/core/conversion/conversion.h
+++ b/core/conversion/conversion.h
@@ -15,6 +15,7 @@ struct InputRange {
nvinfer1::Dims max;
nvinfer1::Dims opt;
nvinfer1::Dims input_shape;
+ bool input_is_dynamic = false;
// Should we restrict to unsigned?
InputRange(std::vector<int64_t> d);
InputRange(std::vector<int64_t> min_shape,
diff --git a/core/conversion/conversionctx/ConversionCtx.h b/core/conversion/conversionctx/ConversionCtx.h
index 76653037a9..abd49cf22e 100644
--- a/core/conversion/conversionctx/ConversionCtx.h
+++ b/core/conversion/conversionctx/ConversionCtx.h
@@ -42,6 +42,7 @@ struct ConversionCtx {
~ConversionCtx();
+ bool input_is_dynamic = false;
nvinfer1::IBuilder* builder;
nvinfer1::INetworkDefinition* net;
nvinfer1::IBuilderConfig* cfg;
diff --git a/core/conversion/converters/impl/batch_norm.cpp b/core/conversion/converters/impl/batch_norm.cpp
index bd923310a0..a7b6045737 100644
--- a/core/conversion/converters/impl/batch_norm.cpp
+++ b/core/conversion/converters/impl/batch_norm.cpp
@@ -19,12 +19,24 @@ auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
auto orig_shape = input->getDimensions();
auto shape = util::toVec(orig_shape);
auto options = torch::TensorOptions().dtype(torch::kFloat32);
- auto gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
- auto beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
- auto mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
- auto var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
+
+ torch::Tensor gamma, beta, mean, var;
+
+ if (ctx->input_is_dynamic) {
+ gamma = args[1].unwrapToTensor();
+ beta = args[2].unwrapToTensor();
+ mean = args[3].unwrapToTensor();
+ var = args[4].unwrapToTensor();
+ } else {
+ gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
+ beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
+ mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
+ var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
+ }
+
auto eps = args[7].unwrapToDouble(1e-5f);
+
LOG_DEBUG("momentum disregarded");
LOG_DEBUG("training disregarded");
LOG_DEBUG("cudnn disregarded");
diff --git a/core/conversion/converters/impl/concat.cpp b/core/conversion/converters/impl/concat.cpp
index da3853291c..2063d8921f 100644
--- a/core/conversion/converters/impl/concat.cpp
+++ b/core/conversion/converters/impl/concat.cpp
@@ -8,7 +8,7 @@ namespace conversion {
namespace converters {
namespace impl {
namespace {
-auto cat_registrations = RegisterNodeConversionPatterns()
+auto cat_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"aten::cat(Tensor[] tensors, int dim=0) -> Tensor",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
diff --git a/core/conversion/converters/impl/constant.cpp b/core/conversion/converters/impl/constant.cpp
index 432eb6bf85..1c23cb6a8b 100644
--- a/core/conversion/converters/impl/constant.cpp
+++ b/core/conversion/converters/impl/constant.cpp
@@ -7,7 +7,7 @@ namespace conversion {
namespace converters {
namespace impl {
namespace {
-auto constant_registrations = RegisterNodeConversionPatterns()
+auto constant_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"trt::const(Tensor self) -> Tensor",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
diff --git a/core/conversion/converters/impl/conv_deconv.cpp b/core/conversion/converters/impl/conv_deconv.cpp
index 37cf3ff3ad..3388a26741 100644
--- a/core/conversion/converters/impl/conv_deconv.cpp
+++ b/core/conversion/converters/impl/conv_deconv.cpp
@@ -9,7 +9,7 @@ namespace conversion {
namespace converters {
namespace impl {
namespace {
-auto conv_registrations = RegisterNodeConversionPatterns()
+auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
R"SIG(aten::_convolution(Tensor input, Tensor weight,
Tensor? bias, int[] stride, int[] padding,
diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp
index 375e7a2d8f..4cb2e03a19 100644
--- a/core/conversion/converters/impl/element_wise.cpp
+++ b/core/conversion/converters/impl/element_wise.cpp
@@ -68,7 +68,7 @@ nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOpera
}
-auto element_wise_registrations = RegisterNodeConversionPatterns()
+auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"aten::add.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> Tensor",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
diff --git a/core/conversion/converters/impl/linear.cpp b/core/conversion/converters/impl/linear.cpp
index f4c49ec020..e22664afe0 100644
--- a/core/conversion/converters/impl/linear.cpp
+++ b/core/conversion/converters/impl/linear.cpp
@@ -8,7 +8,7 @@ namespace converters {
namespace impl {
namespace {
-auto linear_registrations = RegisterNodeConversionPatterns()
+auto linear_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"aten::linear(Tensor input, Tensor weight, Tensor? bias = None) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
diff --git a/core/conversion/converters/impl/matrix_multiply.cpp b/core/conversion/converters/impl/matrix_multiply.cpp
index c6d2d99f1e..cbebdc13b2 100644
--- a/core/conversion/converters/impl/matrix_multiply.cpp
+++ b/core/conversion/converters/impl/matrix_multiply.cpp
@@ -8,7 +8,7 @@ namespace converters {
namespace impl {
namespace {
-auto mm_registrations = RegisterNodeConversionPatterns()
+auto mm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"aten::matmul(Tensor self, Tensor other) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
diff --git a/core/conversion/converters/impl/pooling.cpp b/core/conversion/converters/impl/pooling.cpp
index 04472ce5fc..e18c78c1ed 100644
--- a/core/conversion/converters/impl/pooling.cpp
+++ b/core/conversion/converters/impl/pooling.cpp
@@ -8,7 +8,7 @@ namespace converters {
namespace impl {
namespace {
-auto pooling_registrations = RegisterNodeConversionPatterns()
+auto pooling_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=[0, 0], int[2] dilation=[1, 1], bool ceil_mode=False) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
diff --git a/core/conversion/converters/impl/reduce.cpp b/core/conversion/converters/impl/reduce.cpp
index 0127f83285..16e0d9dd83 100644
--- a/core/conversion/converters/impl/reduce.cpp
+++ b/core/conversion/converters/impl/reduce.cpp
@@ -11,7 +11,7 @@ namespace {
-auto reduce_registrations = RegisterNodeConversionPatterns()
+auto reduce_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"aten::mean(Tensor self, *, ScalarType? dtype=None) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
diff --git a/core/conversion/converters/impl/shape.cpp b/core/conversion/converters/impl/shape.cpp
index d5b3577a34..613ce43fe9 100644
--- a/core/conversion/converters/impl/shape.cpp
+++ b/core/conversion/converters/impl/shape.cpp
@@ -9,7 +9,7 @@ namespace converters {
namespace impl {
namespace {
-static auto shape_registrations = RegisterNodeConversionPatterns()
+static auto shape_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
// To use in static input size cases (explicit batch)
"aten::size.int(Tensor self, int dim) -> (Tensor)",
diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp
index ceda35a5d9..951635a8fc 100644
--- a/core/conversion/converters/impl/shuffle.cpp
+++ b/core/conversion/converters/impl/shuffle.cpp
@@ -9,7 +9,7 @@ namespace converters {
namespace impl {
namespace {
-static auto shuffle_registrations = RegisterNodeConversionPatterns()
+static auto shuffle_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
@@ -50,12 +50,10 @@ static auto shuffle_registrations = RegisterNodeConversionPatterns()
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
auto in_shape = util::toVec(in->getDimensions());
- auto ex_tensor = torch::rand(in_shape);
- auto new_shape = ex_tensor.view(args[1].unwrapToIntList().vec()).sizes();
auto shuffle = ctx->net->addShuffle(*in);
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
- shuffle->setReshapeDimensions(util::toDims(new_shape));
+ shuffle->setReshapeDimensions(util::toDims(args[1].unwrapToIntList().vec()));
shuffle->setName(util::node_info(n).c_str());
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
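The simplification works because TensorRT's IShuffleLayer understands the same -1 placeholder that aten::view uses: setReshapeDimensions accepts one -1 dimension and infers it from the input volume, so the converter no longer needs to build an example tensor just to resolve the shape. A minimal sketch (toy helper, assuming a network and input tensor already exist):

```cpp
#include "NvInfer.h"

// Toy helper: reshape `in` to (-1, inner); TensorRT infers the -1 dimension
// from the total element count, mirroring aten::view semantics.
nvinfer1::IShuffleLayer* reshape_to_2d(nvinfer1::INetworkDefinition* net,
                                       nvinfer1::ITensor* in, int32_t inner) {
  auto* shuffle = net->addShuffle(*in);
  nvinfer1::Dims d;
  d.nbDims = 2;
  d.d[0] = -1;     // wildcard: inferred from the input volume
  d.d[1] = inner;  // fixed inner dimension
  shuffle->setReshapeDimensions(d);
  return shuffle;
}
```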
diff --git a/core/conversion/converters/impl/softmax.cpp b/core/conversion/converters/impl/softmax.cpp
index 35f6f04ef1..6a81b974a2 100644
--- a/core/conversion/converters/impl/softmax.cpp
+++ b/core/conversion/converters/impl/softmax.cpp
@@ -7,7 +7,7 @@ namespace converters {
namespace impl {
namespace {
-static auto softmax_registrations = RegisterNodeConversionPatterns()
+static auto softmax_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"aten::softmax.int(Tensor self, int dim, int? dtype=None) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
diff --git a/core/conversion/tensorcontainer/TensorContainer.cpp b/core/conversion/tensorcontainer/TensorContainer.cpp
index 536d578eae..6fad66335d 100644
--- a/core/conversion/tensorcontainer/TensorContainer.cpp
+++ b/core/conversion/tensorcontainer/TensorContainer.cpp
@@ -6,7 +6,7 @@ namespace conversion {
namespace {
static auto tensor_container =
- torch::class_<TensorContainer>("_eval_ivalue_types", "TensorContainer")
+ torch::class_<TensorContainer>("_trtorch_eval_ivalue_types", "TensorContainer")
.def(torch::init<>());
} // namespace
} // conversion
diff --git a/core/conversion/var/BUILD b/core/conversion/var/BUILD
index e1c92efb12..247f939e48 100644
--- a/core/conversion/var/BUILD
+++ b/core/conversion/var/BUILD
@@ -30,7 +30,7 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar")
pkg_tar(
name = "include",
- package_dir = "core/conversion/arg/",
+ package_dir = "core/conversion/var/",
srcs = [
"Var.h",
"Var_inl.h"
diff --git a/core/execution/BUILD b/core/execution/BUILD
index 009092d3e6..1741249624 100644
--- a/core/execution/BUILD
+++ b/core/execution/BUILD
@@ -14,7 +14,6 @@ cc_library(
],
srcs = [
"TRTEngine.cpp",
- "TRTEngineManager.cpp",
"register_trt_op.cpp",
],
deps = [
diff --git a/core/execution/TRTEngine.cpp b/core/execution/TRTEngine.cpp
index 3370ea6f5b..3d4dbc8033 100644
--- a/core/execution/TRTEngine.cpp
+++ b/core/execution/TRTEngine.cpp
@@ -10,12 +10,32 @@ namespace trtorch {
namespace core {
namespace execution {
-TRTEngine::TRTEngine(nvinfer1::ILogger& logger, std::string& serialized_engine) {
+std::string slugify(std::string s) {
+ std::replace(s.begin(), s.end(), '.', '_');
+ return s;
+}
+
+TRTEngine::TRTEngine(std::string serialized_engine)
+ : logger(std::string("[] - "),
+ util::logging::get_logger().get_reportable_severity(),
+ util::logging::get_logger().get_is_colored_output_on()) {
+ std::string _name = "deserialized_trt";
+ new (this) TRTEngine(_name, serialized_engine);
+}
+
+TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine)
+ : logger(std::string("[") + mod_name + std::string("_engine] - "),
+ util::logging::get_logger().get_reportable_severity(),
+ util::logging::get_logger().get_is_colored_output_on()) {
+
rt = nvinfer1::createInferRuntime(logger);
+ name = slugify(mod_name) + "_engine";
+
cuda_engine = rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size());
// Easy way to get a unique name for each engine, maybe there is a more descriptive way (using something associated with the graph maybe)
id = reinterpret_cast<EngineID>(cuda_engine);
+
exec_ctx = cuda_engine->createExecutionContext();
uint64_t inputs = 0;
@@ -40,7 +60,28 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
return (*this);
}
+// TODO: Implement a call method
+// c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
+// auto input_vec = inputs.vec();
+// auto output_vec = RunCudaEngine(exec_ctx, num_io, input_vec);
+//
+// return c10::List<at::Tensor>(output_vec);
+// }
+
+static auto TRTORCH_UNUSED TRTEngineTSRegistration = torch::class_<TRTEngine>("tensorrt", "Engine")
+ .def(torch::init<std::string>())
+ // TODO: .def("__call__", &TRTEngine::Run)
+ // TODO: .def("run", &TRTEngine::Run)
+ .def_pickle(
+ [](const c10::intrusive_ptr<TRTEngine>& self) -> std::string {
+ auto serialized_engine = self->cuda_engine->serialize();
+ return std::string((const char*)serialized_engine->data(), serialized_engine->size());
+ },
+ [](std::string serialized_engine) -> c10::intrusive_ptr<TRTEngine> {
+ return c10::make_intrusive<TRTEngine>(std::move(serialized_engine));
+ }
+ );
+
} // namespace execution
} // namespace core
} // namespace trtorch
-
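The register_attribute call in compiler.cpp and the def_pickle registration above are two halves of the same mechanism: a torch::CustomClassHolder subclass stored as a module attribute is serialized through its pickle lambdas when the module is saved. A minimal sketch of that machinery with a toy class (not the TRTorch API):

```cpp
#include <string>
#include "torch/custom_class.h"

// Toy custom class: the string payload stands in for a serialized engine.
struct Blob : torch::CustomClassHolder {
  std::string payload;
  explicit Blob(std::string p) : payload(std::move(p)) {}
};

static auto blob_reg = torch::class_<Blob>("demo", "Blob")
    .def(torch::init<std::string>())
    .def_pickle(
        // __getstate__: what gets written into the saved module
        [](const c10::intrusive_ptr<Blob>& self) -> std::string {
          return self->payload;
        },
        // __setstate__: rebuild the object on torch::jit::load
        [](std::string state) -> c10::intrusive_ptr<Blob> {
          return c10::make_intrusive<Blob>(std::move(state));
        });
```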
diff --git a/core/execution/TRTEngineManager.cpp b/core/execution/TRTEngineManager.cpp
deleted file mode 100644
index 27a6aeff28..0000000000
--- a/core/execution/TRTEngineManager.cpp
+++ /dev/null
@@ -1,82 +0,0 @@
-#include "core/util/prelude.h"
-#include "core/execution/execution.h"
-
-namespace trtorch {
-namespace core {
-namespace execution {
-namespace {
-class TRTEngineManager {
-public:
- TRTEngineManager()
- : logger_("[TRTorch Execution Manager] - ",
- util::logging::get_logger().get_reportable_severity(),
- util::logging::get_logger().get_is_colored_output_on()) {
- }
-
- TRTEngine* get_engine(EngineID uid) {
- auto iter = engine_registry_.find(uid);
-
- TRTORCH_ASSERT(iter != engine_registry_.end(), "Unabled to find requested engine (ID: " << uid << ") in TensorRT Execution Manager");
-
- return &(iter->second);
- }
-
- // TODO: Should we have standing engines ready to run or should we be creating execution contexts JIT?
- EngineID register_engine(std::string& serialized_engine) {
- auto engine = TRTEngine(logger_, serialized_engine);
- EngineID uid = engine.id;
- engine_registry_[uid] = std::move(engine);
- LOG_DEBUG(logger_, "Registering new engine (ID: " << std::hex << uid << ") in TensorRT Execution Manager");
- return uid;
- }
-
- void deregister_engine(EngineID uid) {
- auto iter = engine_registry_.find(uid);
- TRTORCH_ASSERT(iter != engine_registry_.end(), "Unabled to find requested engine (ID: " << uid << ") in TensorRT Execution Manager");
-
- auto engine = iter->second;
- // Doing this here since for some reason the destructor causes segfaults
- engine.exec_ctx->destroy();
- engine.cuda_engine->destroy();
- engine_registry_.erase(uid);
- }
-
-private:
- util::logging::TRTorchLogger logger_;
- std::unordered_map<EngineID, TRTEngine> engine_registry_;
-};
-
-TRTEngineManager& get_engine_manager() {
- static TRTEngineManager engine_man;
- return engine_man;
-}
-} // namespace
-
-uint64_t RegisterEngineFromSerializedEngine(std::string& serialized_engine) {
- return get_engine_manager().register_engine(serialized_engine);
-}
-
-nvinfer1::ICudaEngine* GetCudaEngine(EngineID id) {
- // Assuming exception will be thrown inside the manager if there is no corresponding engine
- return get_engine_manager().get_engine(id)->cuda_engine;
-}
-
-nvinfer1::IExecutionContext* GetExecCtx(EngineID id) {
- // Assuming exception will be thrown inside the manager if there is no corresponding engine
- return get_engine_manager().get_engine(id)->exec_ctx;
-}
-
-std::pair<uint64_t, uint64_t> GetEngineIO(EngineID id) {
- // Assuming exception will be thrown inside the manager if there is no corresponding engine
- return get_engine_manager().get_engine(id)->num_io;
-}
-
-void DeregisterEngine(EngineID id) {
- get_engine_manager().deregister_engine(id);
-}
-
-} // namespace execution
-} // namespace core
-} // namespace trtorch
-
-
diff --git a/core/execution/execution.h b/core/execution/execution.h
index 8c50dd4207..9b0ca41cb4 100644
--- a/core/execution/execution.h
+++ b/core/execution/execution.h
@@ -2,6 +2,9 @@
#include <utility>
#include "NvInfer.h"
#include "ATen/core/function_schema.h"
+#include "torch/custom_class.h"
+#include "core/util/prelude.h"
+
namespace trtorch {
namespace core {
@@ -9,25 +12,25 @@ namespace execution {
using EngineID = int64_t;
-struct TRTEngine {
+struct TRTEngine : torch::CustomClassHolder {
// Each engine needs its own runtime object
nvinfer1::IRuntime* rt;
nvinfer1::ICudaEngine* cuda_engine;
nvinfer1::IExecutionContext* exec_ctx;
std::pair<uint64_t, uint64_t> num_io;
EngineID id;
+ std::string name;
+ util::logging::TRTorchLogger logger;
TRTEngine() = default;
- TRTEngine(nvinfer1::ILogger& logger, std::string& serialized_engine);
+ TRTEngine(std::string serialized_engine);
+ TRTEngine(std::string mod_name, std::string serialized_engine);
TRTEngine& operator=(const TRTEngine& other);
+ // TODO: Implement a call method
+ //c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
};
-void RegisterEngineOp(TRTEngine& engine);
-uint64_t RegisterEngineFromSerializedEngine(std::string& serialized_engine);
-nvinfer1::ICudaEngine* GetCudaEngine(EngineID id);
-nvinfer1::IExecutionContext* GetExecCtx(EngineID id);
-std::pair<uint64_t, uint64_t> GetEngineIO(EngineID id);
-void DeregisterEngine(EngineID id);
+std::vector<at::Tensor> RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pair<uint64_t, uint64_t> io, std::vector<at::Tensor>& inputs);
} // namespace execution
} // namespace core
diff --git a/core/execution/register_trt_op.cpp b/core/execution/register_trt_op.cpp
index d9f57452dc..b7c10912be 100644
--- a/core/execution/register_trt_op.cpp
+++ b/core/execution/register_trt_op.cpp
@@ -9,7 +9,6 @@
namespace trtorch {
namespace core {
namespace execution {
-namespace {
std::vector<at::Tensor> RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pair<uint64_t, uint64_t> io, std::vector<at::Tensor>& inputs) {
std::vector<void*> gpu_handles;
@@ -47,6 +46,7 @@ std::vector<at::Tensor> RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pai
return outputs;
}
+namespace {
c10::AliasAnalysisKind aliasAnalysisFromSchema() {
return c10::AliasAnalysisKind::FROM_SCHEMA;
}
@@ -54,27 +54,19 @@ c10::AliasAnalysisKind aliasAnalysisFromSchema() {
// Switched to a global operator because op implementations need to be non-capturing lambdas in PYT 1.5.0+
torch::jit::RegisterOperators jit_registry({
torch::jit::Operator(
- "trt::execute_engine(int id, ...) -> ...",
+ "trt::execute_engine(Tensor[] inputs, __torch__.torch.classes.tensorrt.Engine engine) -> Tensor[]",
[](torch::jit::Stack& stack) -> int {
- size_t num_inputs = torch::jit::pop(stack).toInt();
// Verify calling convention (right to left or left to right)
- std::vector<at::Tensor> inputs;
- for (uint64_t i = 0; i < num_inputs - 1; i++) {
- at::Tensor in;
- torch::jit::pop(stack, in);
- inputs.insert(inputs.begin(), std::move(in));
- }
+ auto engine = torch::jit::pop(stack).toCustomClass<TRTEngine>();
+ LOG_DEBUG("Attempting to run engine (ID: " << std::hex << engine->name << ")");
+
+ auto inputs = torch::jit::pop(stack).toTensorVector();
- int64_t id = torch::jit::pop(stack).toInt();
- LOG_DEBUG("Attempting to run engine (ID: " << std::hex << id << ")");
- auto io = GetEngineIO(id);
- auto num_out = io.second;
+ auto io = engine->num_io;
- auto ctx = GetExecCtx(id);
+ auto ctx = engine->exec_ctx;
auto outputs = RunCudaEngine(ctx, io, inputs);
- for (uint64_t o = 0; o < num_out; o++) {
- torch::jit::push(stack, std::move(outputs[o]));
- }
+ torch::jit::push(stack, std::move(outputs));
return 0;
},
aliasAnalysisFromSchema())
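The pop order above follows the torch::jit::Stack convention: arguments are pushed left-to-right, so the rightmost schema argument (here the engine) comes off the stack first. A toy operator (not part of TRTorch) showing the same pattern:

```cpp
#include "torch/script.h"

// demo::tail(Tensor[] inputs, int tag) -> Tensor[]
// The rightmost argument is popped first, then the list; the list result is
// pushed back as a single IValue.
static torch::jit::RegisterOperators demo_registry({
    torch::jit::Operator(
        "demo::tail(Tensor[] inputs, int tag) -> Tensor[]",
        [](torch::jit::Stack& stack) -> int {
          auto tag = torch::jit::pop(stack).toInt();  // rightmost first
          auto inputs = torch::jit::pop(stack).toTensorVector();
          (void)tag;
          torch::jit::push(stack, std::move(inputs));
          return 0;
        },
        c10::AliasAnalysisKind::FROM_SCHEMA)});
```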
diff --git a/cpp/api/include/trtorch/ptq.h b/cpp/api/include/trtorch/ptq.h
index afae26a85c..ce59395b4c 100644
--- a/cpp/api/include/trtorch/ptq.h
+++ b/cpp/api/include/trtorch/ptq.h
@@ -104,18 +104,17 @@ class Int8Calibrator : Algorithm {
std::stringstream ss;
ss << "Reading Calibration Cache from " << cache_file_path_;
logging::log(logging::Level::kINFO, ss.str());
+
cache_.clear();
- std::ifstream cache_file(cache_file_path_, std::ios::binary);
- cache_file >> std::noskipws;
- if (cache_file.good()) {
- std::copy(std::istream_iterator<char>(cache_file),
- std::istream_iterator<char>(),
- std::back_inserter(cache_));
- ss << "Cache read";
- logging::log(logging::Level::kDEBUG, ss.str());
+ std::ifstream input(cache_file_path_, std::ios::binary);
+ input >> std::noskipws;
+ if (input.good()) {
+ std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(),
+ std::back_inserter(cache_));
+ logging::log(logging::Level::kDEBUG, "Cache read");
}
- cache_size_ = cache_.size();
- return cache_size_ ? cache_.data() : nullptr;
+ length = cache_.size();
+ return length ? cache_.data() : nullptr;
}
return nullptr;
}
@@ -220,23 +219,17 @@ class Int8CacheCalibrator : Algorithm {
std::stringstream ss;
ss << "Reading Calibration Cache from " << cache_file_path_;
logging::log(logging::Level::kINFO, ss.str());
+
cache_.clear();
- std::ifstream cache_file;
- cache_file.open(cache_file_path_, std::ios::in | std::ios::binary);
- cache_file.unsetf(std::ios::skipws);
- cache_file.seekg(0, std::ios::beg);
- cache_.reserve(cache_file.tellg());
- cache_file.seekg(0, std::ios::beg);
- if (cache_file.good()) {
- std::cout << "Trying to read cache" << std::endl;
- std::copy(std::istreambuf_iterator<char>(cache_file),
- std::istreambuf_iterator<char>(),
- std::back_inserter(cache_));
- ss << "Cache read";
- logging::log(logging::Level::kDEBUG, ss.str());
+ std::ifstream input(cache_file_path_, std::ios::binary);
+ input >> std::noskipws;
+ if (input.good()) {
+ std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(),
+ std::back_inserter(cache_));
+ logging::log(logging::Level::kDEBUG, "Cache read");
}
- cache_size_ = cache_.size();
- return cache_size_ ? cache_.data() : nullptr;
+ length = cache_.size();
+ return length ? cache_.data() : nullptr;
}
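The rewritten reader in both calibrators boils down to the same pattern; as a standalone sketch (hypothetical helper name):

```cpp
#include <fstream>
#include <iterator>
#include <string>
#include <vector>

// Read a whole binary file into memory; std::noskipws keeps whitespace bytes.
std::vector<char> read_binary_file(const std::string& path) {
  std::ifstream input(path, std::ios::binary);
  input >> std::noskipws;
  return std::vector<char>(std::istream_iterator<char>(input),
                           std::istream_iterator<char>());
}
```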
diff --git a/cpp/api/include/trtorch/trtorch.h b/cpp/api/include/trtorch/trtorch.h
index 9b3f98e355..8f26e0bd8f 100644
--- a/cpp/api/include/trtorch/trtorch.h
+++ b/cpp/api/include/trtorch/trtorch.h
@@ -142,6 +142,14 @@ struct TRTORCH_API ExtraInfo {
* @return false
*/
constexpr bool operator==(DataType other) const { return value == other.value; }
+ /**
* @brief Comparison operator for DataType
+ *
+ * @param other
+ * @return true
+ * @return false
+ */
+ constexpr bool operator==(DataType::Value other) const { return value == other; }
/**
* @brief Comparison operator for DataType
*
@@ -150,6 +158,14 @@ struct TRTORCH_API ExtraInfo {
* @return false
*/
constexpr bool operator!=(DataType other) const { return value != other.value; }
+ /**
* @brief Comparison operator for DataType
+ *
+ * @param other
+ * @return true
+ * @return false
+ */
+ constexpr bool operator!=(DataType::Value other) const { return value != other; }
private:
Value value;
};
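The new overloads let callers compare a DataType directly against the nested enum without constructing a temporary; trtorchc's main.cpp below relies on exactly this. A short sketch:

```cpp
#include "trtorch/trtorch.h"

bool is_fp32(const trtorch::ExtraInfo& settings) {
  // With the DataType::Value overload this comparison needs no temporary
  // DataType object wrapping the enum value.
  return settings.op_precision == trtorch::ExtraInfo::DataType::kFloat;
}
```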
diff --git a/cpp/trtorchc/BUILD b/cpp/trtorchc/BUILD
new file mode 100644
index 0000000000..7fa89836f5
--- /dev/null
+++ b/cpp/trtorchc/BUILD
@@ -0,0 +1,14 @@
+package(default_visibility = ["//visibility:public"])
+
+cc_binary(
+ name = "trtorchc",
+ srcs = [
+ "main.cpp"
+ ],
+ deps = [
+ "@libtorch//:libtorch",
+ "@libtorch//:caffe2",
+ "//third_party/args",
+ "//cpp/api:trtorch"
+ ],
+)
diff --git a/cpp/trtorchc/README.md b/cpp/trtorchc/README.md
new file mode 100644
index 0000000000..25a59efb27
--- /dev/null
+++ b/cpp/trtorchc/README.md
@@ -0,0 +1,87 @@
+# trtorchc
+
+trtorchc is a CLI application for the TRTorch compiler. It serves as an easy way to compile a
+TorchScript Module with TRTorch from the command-line, either to quickly check support or as part of
+a deployment pipeline. All basic features of the compiler are supported, including post training
+quantization (though you must already have a calibration cache file to use it). The compiler can
+output two formats: either a TorchScript program with the TensorRT engine embedded, or
+the TensorRT engine itself as a PLAN file.
+
+All that is required to run a compiled program is, in C++, linking against libtrtorch.so,
+or in Python, importing the trtorch package. All other aspects of using compiled modules are identical
+to standard TorchScript. Load with `torch.jit.load()` and run like you would run any other module; a minimal C++ consumer is sketched after the example below.
+
+
+```
+trtorchc [input_file_path] [output_file_path]
+ [input_shapes...] {OPTIONS}
+
+ TRTorch is a compiler for TorchScript, it will compile and optimize
+ TorchScript programs to run on NVIDIA GPUs using TensorRT
+
+ OPTIONS:
+
+ -h, --help Display this help menu
+ Verbosity of the compiler
+ -v, --verbose Dumps debugging information about the
+ compilation process onto the console
+ -w, --warnings Disables warnings generated during
+ compilation onto the console (warnings
+ are on by default)
+ --info Dumps info messages generated during
+ compilation onto the console
+ --build-debuggable-engine Creates a debuggable engine
+ --use-strict-types Restrict operating type to only use set
+ default operation precision
+ (op_precision)
+ --allow-gpu-fallback (Only used when targeting DLA
+ (device-type)) Lets engine run layers on
+ GPU if they are not supported on DLA
+ -p[precision],
+ --default-op-precision=[precision]
+ Default operating precision for the
+ engine (Int8 requires a
+ calibration-cache argument) [ float |
+ float32 | f32 | half | float16 | f16 |
+ int8 | i8 ] (default: float)
+ -d[type], --device-type=[type] The type of device the engine should be
+ built for [ gpu | dla ] (default: gpu)
+ --engine-capability=[capability] The type of device the engine should be
+ built for [ default | safe_gpu |
+ safe_dla ]
+ --calibration-cache-file=[file_path]
+ Path to calibration cache file to use
+ for post training quantization
+ --num-min-timing-iter=[num_iters] Number of minimization timing iterations
+ used to select kernels
+ --num-avg-timing-iters=[num_iters]
+ Number of averaging timing iterations
+ used to select kernels
+ --workspace-size=[workspace_size] Maximum size of workspace given to
+ TensorRT
+ --max-batch-size=[max_batch_size] Maximum batch size (must be >= 1 to be
+ set, 0 means not set)
+ -t[threshold],
+ --threshold=[threshold] Maximum acceptable numerical deviation
+ from standard torchscript output
+ (default 2e-5)
+ --save-engine Instead of compiling a full
+ TorchScript program, save the created
+ engine to the path specified as the
+ output path
+ input_file_path Path to input TorchScript file
+ output_file_path Path for compiled TorchScript (or
+ TensorRT engine) file
+ input_shapes... Sizes for inputs to engine, can either
+ be a single size or a range defined by
+ Min, Optimal, Max sizes, e.g.
+ "(N,..,C,H,W)"
+ "[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]"
+ "--" can be used to terminate flag options and force all following
+ arguments to be treated as positional options
+```
+
+e.g.
+```
+trtorchc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]" -p f16
+```
\ No newline at end of file
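A minimal C++ consumer of the module the example produces (a sketch; the only TRTorch-specific requirement is linking libtrtorch.so so the trt::execute_engine op and tensorrt.Engine class are registered):

```cpp
#include "torch/script.h"

int main() {
  // Output of the trtorchc invocation above
  auto mod = torch::jit::load("ssd_trt.ts");
  auto out = mod.forward({torch::randn({1, 3, 300, 300}, at::kCUDA)});
  return 0;
}
```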
diff --git a/cpp/trtorchc/main.cpp b/cpp/trtorchc/main.cpp
new file mode 100644
index 0000000000..5dab59a4ea
--- /dev/null
+++ b/cpp/trtorchc/main.cpp
@@ -0,0 +1,366 @@
+#include <stdlib.h>
+#include <unistd.h>
+#include <iostream>
+#include <sstream>
+
+#ifdef linux
+#include <linux/limits.h>
+#else
+#define PATH_MAX 260
+#endif
+
+#include "NvInfer.h"
+#include "third_party/args/args.hpp"
+#include "torch/torch.h"
+#include "torch/script.h"
+#include "trtorch/trtorch.h"
+
+bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs, float threshold) {
+ double maxValue = 0.0;
+ for (auto& tensor : inputs) {
+ maxValue = fmax(tensor.abs().max().item<float>(), maxValue);
+ }
+ trtorch::logging::log(trtorch::logging::Level::kDEBUG, std::string("Max Difference: ") + std::to_string(diff.abs().max().item<float>()));
+ trtorch::logging::log(trtorch::logging::Level::kDEBUG, std::string("Acceptable Threshold: ") + std::to_string(threshold));
+ return diff.abs().max().item<float>() <= threshold * maxValue;
+}
+
+bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold) {
+ return checkRtol(a - b, {a, b}, threshold);
+}
+
+std::vector<int64_t> parseSingleDim(std::string shape_str) {
+ std::vector<int64_t> shape;
+ std::stringstream ss;
+ for (auto c : shape_str) {
+ if (c == '(' || c == ' ') {
+ continue;
+ } else if (c == ',') {
+ int64_t dim;
+ ss >> dim;
+ shape.push_back(dim);
+ ss.clear();
+ } else if (c == ')') {
+ int64_t dim;
+ ss >> dim;
+ shape.push_back(dim);
+ ss.clear();
+ return shape;
+ } else {
+ ss << c;
+ }
+ }
+
+ trtorch::logging::log(trtorch::logging::Level::kERROR, "Shapes need dimensions delimited by comma in parentheses, \"(N,..,C,H,W)\"\n e.g \"(3,3,200,200)\"");
+ exit(1);
+ return {};
+}
+
+trtorch::ExtraInfo::InputRange parseDynamicDim(std::string shape_str) {
+ shape_str = shape_str.substr(1, shape_str.size() - 2);
+ std::vector<std::vector<int64_t>> shape;
+ std::stringstream ss;
+
+ std::string delimiter = ";";
+
+ size_t pos = 0;
+ while ((pos = shape_str.find(delimiter)) != std::string::npos) {
+ auto token = shape_str.substr(0, pos);
+ auto range = parseSingleDim(token);
+ shape_str.erase(0, pos + delimiter.length());
+ shape.push_back(range);
+ }
+
+ auto range = parseSingleDim(shape_str);
+ shape.push_back(range);
+
+ if (shape.size() != 3) {
+ trtorch::logging::log(trtorch::logging::Level::kERROR, "Dynamic shapes need three sets of dimensions delimited by semi-colons, \"[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]\"\n e.g \"[(3,3,100,100);(3,3,200,200);(3,3,300,300)]\"");
+ exit(1);
+ }
+
+ return trtorch::ExtraInfo::InputRange(shape[0], shape[1], shape[2]);
+}
+
+std::string get_cwd() {
+ char buff[FILENAME_MAX]; //create string buffer to hold path
+ if (getcwd(buff, FILENAME_MAX)) {
+ std::string current_working_dir(buff);
+ return current_working_dir;
+ } else {
+ trtorch::logging::log(trtorch::logging::Level::kERROR, "Unable to get current directory");
+ exit(1);
+ }
+}
+
+std::string real_path(std::string path) {
+ auto abs_path = path;
+ char real_path_c[PATH_MAX];
+ char* res = realpath(abs_path.c_str(), real_path_c);
+ if (res) {
+ return std::string(real_path_c);
+ } else {
+ trtorch::logging::log(trtorch::logging::Level::kERROR, std::string("Unable to find file ") + abs_path);
+ exit(1);
+ }
+}
+
+std::string resolve_path(std::string path) {
+ auto rpath = path;
+ if (!(rpath.rfind("/", 0) == 0)) {
+ rpath = get_cwd() + '/' + rpath;
+ }
+ return rpath;
+}
+
+int main(int argc, char** argv) {
+ trtorch::logging::set_is_colored_output_on(true);
+ trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kWARNING);
+ trtorch::logging::set_logging_prefix("");
+
+
+ args::ArgumentParser parser("TRTorch is a compiler for TorchScript, it will compile and optimize TorchScript programs to run on NVIDIA GPUs using TensorRT", "");
+ args::HelpFlag help(parser, "help", "Display this help menu", {'h', "help"});
+
+ args::Group group(parser, "Verbosity of the compiler", args::Group::Validators::AtMostOne);
+ args::Flag verbose(group, "verbose", "Dumps debugging information about the compilation process onto the console", {'v', "verbose"});
+ args::Flag warning(group, "warning", "Disables warnings generated during compilation onto the console (warnings are on by default)", {'w', "warnings"});
+ args::Flag info(group, "info", "Dumps info messages generated during compilation onto the console", {"i", "info"});
+
+ args::Flag build_debuggable_engine(parser, "build-debuggable-engine", "Creates a debuggable engine", {"build-debuggable-engine"});
+ args::Flag use_strict_types(parser, "use-strict-types", "Restrict operating type to only use set default operation precision (op_precision)", {"use-strict-types"});
+ args::Flag allow_gpu_fallback(parser, "allow-gpu-fallback", "(Only used when targeting DLA (device-type)) Lets engine run layers on GPU if they are not supported on DLA", {"allow-gpu-fallback"});
+
+ args::ValueFlag<std::string> op_precision(parser, "precision", "Default operating precision for the engine (Int8 requires a calibration-cache argument) [ float | float32 | f32 | half | float16 | f16 | int8 | i8 ] (default: float)", {'p', "default-op-precision"});
+ args::ValueFlag<std::string> device_type(parser, "type", "The type of device the engine should be built for [ gpu | dla ] (default: gpu)", {'d', "device-type"});
+ args::ValueFlag<std::string> engine_capability(parser, "capability", "The type of device the engine should be built for [ default | safe_gpu | safe_dla ]", {"engine-capability"});
+
+ args::ValueFlag<std::string> calibration_cache_file(parser, "file_path", "Path to calibration cache file to use for post training quantization", {"calibration-cache-file"});
+ args::ValueFlag<int> num_min_timing_iters(parser, "num_iters", "Number of minimization timing iterations used to select kernels", {"num-min-timing-iter"});
+ args::ValueFlag<int> num_avg_timing_iters(parser, "num_iters", "Number of averaging timing iterations used to select kernels", {"num-avg-timing-iters"});
+ args::ValueFlag<uint64_t> workspace_size(parser, "workspace_size", "Maximum size of workspace given to TensorRT", {"workspace-size"});
+ args::ValueFlag<uint64_t> max_batch_size(parser, "max_batch_size", "Maximum batch size (must be >= 1 to be set, 0 means not set)", {"max-batch-size"});
+ args::ValueFlag<double> threshold(parser, "threshold", "Maximum acceptable numerical deviation from standard torchscript output (default 2e-5)", {'t', "threshold"});
+
+
+ args::Flag save_engine(parser, "save_engine", "Instead of compiling a full TorchScript program, save the created engine to the path specified as the output path", {"save-engine"});
+ args::Positional<std::string> input_path(parser, "input_file_path", "Path to input TorchScript file");
+ args::Positional<std::string> output_path(parser, "output_file_path", "Path for compiled TorchScript (or TensorRT engine) file");
+ args::PositionalList<std::string> input_shapes(parser, "input_shapes", "Sizes for inputs to engine, can either be a single size or a range defined by Min, Optimal, Max sizes, e.g. \"(N,..,C,H,W)\" \"[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]\"");
+
+
+ try
+ {
+ parser.ParseCLI(argc, argv);
+ }
+ catch (args::Help)
+ {
+ std::cout << parser;
+ return 0;
+ }
+ catch (args::ParseError e)
+ {
+ std::cerr << e.what() << std::endl;
+ std::cerr << parser;
+ return 1;
+ }
+ catch (args::ValidationError e)
+ {
+ std::cerr << e.what() << std::endl;
+ std::cerr << parser;
+ return 1;
+ }
+
+ if (verbose) {
+ trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kDEBUG);
+ } else if (info) {
+ trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kINFO);
+ } else if (warning) {
+ trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kERROR);
+ }
+
+
+ std::vector<trtorch::ExtraInfo::InputRange> ranges;
+ for (const auto shapes : args::get(input_shapes)) {
+ if (shapes.rfind("(", 0) == 0) {
+ ranges.push_back(trtorch::ExtraInfo::InputRange(parseSingleDim(shapes)));
+ } else if (shapes.rfind("[", 0) == 0) {
+ ranges.push_back(parseDynamicDim(shapes));
+ } else {
+ trtorch::logging::log(trtorch::logging::Level::kERROR, "Dimensions should be specified in one of these types \"(N,..,C,H,W)\" \"[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]\"\n e.g \"(3,3,300,300)\" \"[(3,3,100,100);(3,3,200,200);(3,3,300,300)]\"");
+ std::cerr << parser;
+ exit(1);
+ }
+ }
+
+ auto compile_settings = trtorch::ExtraInfo(ranges);
+
+ if (build_debuggable_engine) {
+ compile_settings.debug = true;
+ }
+
+ if (use_strict_types) {
+ compile_settings.strict_types = true;
+ }
+
+ if (allow_gpu_fallback) {
+ compile_settings.allow_gpu_fallback = true;
+ }
+
+ std::string calibration_cache_file_path = "";
+ if (calibration_cache_file) {
+ calibration_cache_file_path = resolve_path(args::get(calibration_cache_file));
+ }
+
+ auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file_path);
+
+ if (op_precision) {
+ auto precision = args::get(op_precision);
+ std::transform(precision.begin(), precision.end(), precision.begin(), [](unsigned char c){ return std::tolower(c); });
+ if (precision == "float" || precision == "float32" || precision == "f32") {
+ compile_settings.op_precision = torch::kF32;
+ } else if (precision == "half" || precision == "float16" || precision == "f16") {
+ compile_settings.op_precision = torch::kF16;
+ } else if (precision == "int8" || precision == "i8") {
+ compile_settings.op_precision = torch::kI8;
+ if (calibration_cache_file) {
+ compile_settings.ptq_calibrator = calibrator;
+ } else {
+ trtorch::logging::log(trtorch::logging::Level::kERROR, "If targeting INT8 default operating precision with trtorchc, a calibration cache file must be provided");
+ std::cerr << parser;
+ return 1;
+ }
+ } else {
+ trtorch::logging::log(trtorch::logging::Level::kERROR, "Invalid default operating precision, options are [ float | float32 | f32 | half | float16 | f16 | int8 | i8 ]");
+ std::cerr << parser;
+ return 1;
+ }
+ }
+
+ if (device_type) {
+ auto device = args::get(device_type);
+ std::transform(device.begin(), device.end(), device.begin(), [](unsigned char c){ return std::tolower(c); });
+ if (device == "gpu") {
+ compile_settings.device = trtorch::ExtraInfo::DeviceType::kGPU;
+ } else if (device == "dla") {
+ compile_settings.device = trtorch::ExtraInfo::DeviceType::kDLA;
+ } else {
+ trtorch::logging::log(trtorch::logging::Level::kERROR, "Invalid device type, options are [ gpu | dla ]");
+ std::cerr << parser;
+ return 1;
+ }
+ }
+
+ if (engine_capability) {
+ auto capability = args::get(engine_capability);
+ std::transform(capability.begin(), capability.end(), capability.begin(), [](unsigned char c){ return std::tolower(c); });
+ if (capability == "default") {
+ compile_settings.capability = trtorch::ExtraInfo::EngineCapability::kDEFAULT;
+ } else if (capability == "safe_gpu") {
+ compile_settings.capability = trtorch::ExtraInfo::EngineCapability::kSAFE_GPU;
+ } else if (capability == "safe_dla") {
+ compile_settings.capability = trtorch::ExtraInfo::EngineCapability::kSAFE_DLA;
+ } else {
+ trtorch::logging::log(trtorch::logging::Level::kERROR, "Invalid engine capability, options are [ default | safe_gpu | safe_dla ]");
+ std::cerr << parser;
+ return 1;
+ }
+ }
+
+ if (num_min_timing_iters) {
+ compile_settings.num_min_timing_iters = args::get(num_min_timing_iters);
+ }
+
+ if (num_avg_timing_iters) {
+ compile_settings.num_avg_timing_iters = args::get(num_avg_timing_iters);
+ }
+
+ if (workspace_size) {
+ compile_settings.workspace_size = args::get(workspace_size);
+ }
+
+ if (max_batch_size) {
+ compile_settings.max_batch_size = args::get(max_batch_size);
+ }
+
+ auto real_input_path = resolve_path(args::get(input_path));
+ auto real_output_path = resolve_path(args::get(output_path));
+
+ torch::jit::Module mod;
+ try {
+ // Deserialize the ScriptModule from a file using torch::jit::load().
+ mod = torch::jit::load(real_input_path);
+ }
+ catch (const c10::Error& e) {
+ trtorch::logging::log(trtorch::logging::Level::kERROR, "Error loading the model (path may be incorrect)");
+ std::cerr << parser;
+ return 1;
+ }
+
+ if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
+ trtorch::logging::log(trtorch::logging::Level::kERROR, "Module is not currently supported by TRTorch");
+ return 1;
+ }
+
+ if (save_engine) {
+ auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", compile_settings);
+ std::ofstream out(real_output_path);
+ out << engine;
+ out.close();
+ } else {
+ auto trt_mod = trtorch::CompileGraph(mod, compile_settings);
+
+ if (compile_settings.op_precision == trtorch::ExtraInfo::DataType::kFloat) {
+ double threshold_val = 2e-5;
+ if (threshold) {
+ threshold_val = args::get(threshold);
+ }
+
+ std::vector<torch::jit::IValue> jit_inputs_ivalues;
+ std::vector<torch::jit::IValue> trt_inputs_ivalues;
+
+ for (auto i : ranges) {
+ auto in = at::randn(i.opt, {at::kCUDA});
+ jit_inputs_ivalues.push_back(in.clone());
+ trt_inputs_ivalues.push_back(in.clone());
+ }
+
+ torch::jit::IValue jit_results_ivalues = mod.forward(jit_inputs_ivalues);
+ std::vector<at::Tensor> jit_results;
+ if (jit_results_ivalues.isTensor()) {
+ jit_results.push_back(jit_results_ivalues.toTensor());
+ } else {
+ auto results = jit_results_ivalues.toTuple()->elements();
+ for (auto r : results) {
+ jit_results.push_back(r.toTensor());
+ }
+ }
+
+
+ torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues);
+ std::vector<at::Tensor> trt_results;
+ if (trt_results_ivalues.isTensor()) {
+ trt_results.push_back(trt_results_ivalues.toTensor());
+ } else {
+ auto results = trt_results_ivalues.toTuple()->elements();
+ for (auto r : results) {
+ trt_results.push_back(r.toTensor());
+ }
+ }
+
+ for (size_t i = 0; i < trt_results.size(); i++) {
+ if (!almostEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]), threshold_val)) {
+ std::ostringstream threshold_ss;
+ threshold_ss << threshold_val;
+ trtorch::logging::log(trtorch::logging::Level::kWARNING, std::string("Maximum numerical deviation for output exceeds set threshold (") + threshold_ss.str() + std::string(")"));
+ }
+ }
+ } else {
+ trtorch::logging::log(trtorch::logging::Level::kWARNING, "Due to change in operating data type, numerical precision is not checked");
+ }
+
+ trt_mod.save(real_output_path);
+ }
+
+ return 0;
+}
\ No newline at end of file
diff --git a/cpp/trtorchexec/main.cpp b/cpp/trtorchexec/main.cpp
index 2085928b6f..8b3e114e62 100644
--- a/cpp/trtorchexec/main.cpp
+++ b/cpp/trtorchexec/main.cpp
@@ -38,6 +38,7 @@ int main(int argc, const char* argv[]) {
}
mod.to(at::kCUDA);
+ mod.eval();
std::vector<std::vector<int64_t>> dims;
for (int i = 2; i < argc; i++) {
@@ -92,7 +93,7 @@ int main(int argc, const char* argv[]) {
std::cout << "Running TRT module" << std::endl;
torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues);
std::vector<at::Tensor> trt_results;
- if (trt_results_ivalues.isTensor()) {
+ if (trt_results_ivalues.isTensor()) {
trt_results.push_back(trt_results_ivalues.toTensor());
} else {
auto results = trt_results_ivalues.toTuple()->elements();
@@ -106,5 +107,8 @@ int main(int argc, const char* argv[]) {
}
std::cout << "Converted Engine saved to /tmp/engine_converted_from_jit.trt" << std::endl;
+
+ trt_mod.save("/tmp/ts_trt.ts");
+ std::cout << "Compiled TorchScript program saved to /tmp/ts_trt.ts" << std::endl;
std::cout << "ok\n";
}
diff --git a/docs/._index.html b/docs/._index.html
new file mode 100644
index 0000000000..e9528f4621
Binary files /dev/null and b/docs/._index.html differ
diff --git a/docs/_cpp_api/class_view_hierarchy.html b/docs/_cpp_api/class_view_hierarchy.html
index fbdc65da2f..f85121680e 100644
--- a/docs/_cpp_api/class_view_hierarchy.html
+++ b/docs/_cpp_api/class_view_hierarchy.html
@@ -294,6 +294,11 @@
+
+
+ trtorchc
+
+
diff --git a/docs/_cpp_api/classtrtorch_1_1ExtraInfo_1_1DataType.html b/docs/_cpp_api/classtrtorch_1_1ExtraInfo_1_1DataType.html
index 3bb12ab4c5..c87d34fbc9 100644
--- a/docs/_cpp_api/classtrtorch_1_1ExtraInfo_1_1DataType.html
+++ b/docs/_cpp_api/classtrtorch_1_1ExtraInfo_1_1DataType.html
@@ -301,6 +301,11 @@
+
+
+ trtorchc
+
+
@@ -1029,6 +1034,102 @@
+
+
+
+
+ Comparision operator for
+
+
+ DataType
+
+
+ .
+
+
+
+
+
+
+ Return
+
+
+
+
+ true
+
+
+
+
+ Return
+
+
+
+
+ false
+
+
+
+
+ Parameters
+
+
+
+
+
+
+
+
+ other
+
+
+ :
+
+
+
+
+
+
+
+
+
+
+
+ Comparision operator for
+
+
+ DataType
+
+
+ .
+
+
+
+
+
+
+ Return
+
+
+
+
+ true
+
+
+
+
+ Return
+
+
+
+
+ false
+
+
+
+
+ Parameters
+
+
+
+
+
+
+
+
+ other
+
+
+ :
+
+
+
+
+
+
+
diff --git a/docs/_cpp_api/classtrtorch_1_1ExtraInfo_1_1DeviceType.html b/docs/_cpp_api/classtrtorch_1_1ExtraInfo_1_1DeviceType.html
index 2448917b3b..9307df8de1 100644
--- a/docs/_cpp_api/classtrtorch_1_1ExtraInfo_1_1DeviceType.html
+++ b/docs/_cpp_api/classtrtorch_1_1ExtraInfo_1_1DeviceType.html
@@ -301,6 +301,11 @@
+
+
+ trtorchc
+
+
diff --git a/docs/_cpp_api/classtrtorch_1_1ptq_1_1Int8CacheCalibrator.html b/docs/_cpp_api/classtrtorch_1_1ptq_1_1Int8CacheCalibrator.html
index f597b95aab..034a7ad417 100644
--- a/docs/_cpp_api/classtrtorch_1_1ptq_1_1Int8CacheCalibrator.html
+++ b/docs/_cpp_api/classtrtorch_1_1ptq_1_1Int8CacheCalibrator.html
@@ -301,6 +301,11 @@
+
+
+ trtorchc
+
+
diff --git a/docs/_cpp_api/classtrtorch_1_1ptq_1_1Int8Calibrator.html b/docs/_cpp_api/classtrtorch_1_1ptq_1_1Int8Calibrator.html
index fa3ae5e34f..b116598b13 100644
--- a/docs/_cpp_api/classtrtorch_1_1ptq_1_1Int8Calibrator.html
+++ b/docs/_cpp_api/classtrtorch_1_1ptq_1_1Int8Calibrator.html
@@ -301,6 +301,11 @@
+
+
+ trtorchc
+
+
diff --git a/docs/_cpp_api/define_macros_8h_1a18d295a837ac71add5578860b55e5502.html b/docs/_cpp_api/define_macros_8h_1a18d295a837ac71add5578860b55e5502.html
index b60ed74d7c..baab8f4257 100644
--- a/docs/_cpp_api/define_macros_8h_1a18d295a837ac71add5578860b55e5502.html
+++ b/docs/_cpp_api/define_macros_8h_1a18d295a837ac71add5578860b55e5502.html
@@ -301,6 +301,11 @@
+
+
+ trtorchc
+
+
diff --git a/docs/_cpp_api/define_macros_8h_1a20c1fbeb21757871c52299dc52351b5f.html b/docs/_cpp_api/define_macros_8h_1a20c1fbeb21757871c52299dc52351b5f.html
index 26e7ea5e09..4343e11651 100644
--- a/docs/_cpp_api/define_macros_8h_1a20c1fbeb21757871c52299dc52351b5f.html
+++ b/docs/_cpp_api/define_macros_8h_1a20c1fbeb21757871c52299dc52351b5f.html
@@ -301,6 +301,11 @@
+
+
+ trtorchc
+
+
diff --git a/docs/_cpp_api/define_macros_8h_1a25ee153c325dfc7466a33cbd5c1ff055.html b/docs/_cpp_api/define_macros_8h_1a25ee153c325dfc7466a33cbd5c1ff055.html
index 86ad8f94e7..3d633d42c2 100644
--- a/docs/_cpp_api/define_macros_8h_1a25ee153c325dfc7466a33cbd5c1ff055.html
+++ b/docs/_cpp_api/define_macros_8h_1a25ee153c325dfc7466a33cbd5c1ff055.html
@@ -301,6 +301,11 @@
+
+
+ trtorchc
+
+
diff --git a/docs/_cpp_api/define_macros_8h_1a48d6029a45583a06848891cb0e86f7ba.html b/docs/_cpp_api/define_macros_8h_1a48d6029a45583a06848891cb0e86f7ba.html
index 17cd191e37..e66b846828 100644
--- a/docs/_cpp_api/define_macros_8h_1a48d6029a45583a06848891cb0e86f7ba.html
+++ b/docs/_cpp_api/define_macros_8h_1a48d6029a45583a06848891cb0e86f7ba.html
@@ -301,6 +301,11 @@
+
+
+ trtorchc
+
+
diff --git a/docs/_cpp_api/define_macros_8h_1a71b02dddfabe869498ad5a88e11c440f.html b/docs/_cpp_api/define_macros_8h_1a71b02dddfabe869498ad5a88e11c440f.html
index 92e0b92577..70bab59633 100644
--- a/docs/_cpp_api/define_macros_8h_1a71b02dddfabe869498ad5a88e11c440f.html
+++ b/docs/_cpp_api/define_macros_8h_1a71b02dddfabe869498ad5a88e11c440f.html
@@ -301,6 +301,11 @@
+
+
+ trtorchc
+
+
diff --git a/docs/_cpp_api/define_macros_8h_1a9d31d0569348d109b1b069b972dd143e.html b/docs/_cpp_api/define_macros_8h_1a9d31d0569348d109b1b069b972dd143e.html
index f7af8b6d00..cface47f06 100644
--- a/docs/_cpp_api/define_macros_8h_1a9d31d0569348d109b1b069b972dd143e.html
+++ b/docs/_cpp_api/define_macros_8h_1a9d31d0569348d109b1b069b972dd143e.html
@@ -301,6 +301,11 @@
+
+
+ trtorchc
+
+
diff --git a/docs/_cpp_api/define_macros_8h_1abe87b341f562fd1cf40b7672e4d759da.html b/docs/_cpp_api/define_macros_8h_1abe87b341f562fd1cf40b7672e4d759da.html
index adffa0adbb..c235e0d9f1 100644
--- a/docs/_cpp_api/define_macros_8h_1abe87b341f562fd1cf40b7672e4d759da.html
+++ b/docs/_cpp_api/define_macros_8h_1abe87b341f562fd1cf40b7672e4d759da.html
@@ -300,6 +300,11 @@
+
+
+ trtorchc
+
+
diff --git a/docs/_cpp_api/define_macros_8h_1ae1c56ab8a40af292a9a4964651524d84.html b/docs/_cpp_api/define_macros_8h_1ae1c56ab8a40af292a9a4964651524d84.html
index a37e731f3d..71be6933d1 100644
--- a/docs/_cpp_api/define_macros_8h_1ae1c56ab8a40af292a9a4964651524d84.html
+++ b/docs/_cpp_api/define_macros_8h_1ae1c56ab8a40af292a9a4964651524d84.html
@@ -301,6 +301,11 @@
+
+
+ trtorchc
+
+
diff --git a/docs/_cpp_api/dir_cpp.html b/docs/_cpp_api/dir_cpp.html
index c9e4c2ea38..aa3f7ff7a7 100644
--- a/docs/_cpp_api/dir_cpp.html
+++ b/docs/_cpp_api/dir_cpp.html
@@ -294,6 +294,11 @@
+
+
+ trtorchc
+
+
diff --git a/docs/_cpp_api/dir_cpp_api.html b/docs/_cpp_api/dir_cpp_api.html
index 7fbc4e9dc3..347bd9aec6 100644
--- a/docs/_cpp_api/dir_cpp_api.html
+++ b/docs/_cpp_api/dir_cpp_api.html
@@ -294,6 +294,11 @@
+
+
+ trtorchc
+
+
diff --git a/docs/_cpp_api/dir_cpp_api_include.html b/docs/_cpp_api/dir_cpp_api_include.html
index 5e429a64c4..d0dfe41b75 100644
--- a/docs/_cpp_api/dir_cpp_api_include.html
+++ b/docs/_cpp_api/dir_cpp_api_include.html
@@ -294,6 +294,11 @@
+
+
+ trtorchc
+
+
diff --git a/docs/_cpp_api/dir_cpp_api_include_trtorch.html b/docs/_cpp_api/dir_cpp_api_include_trtorch.html
index 21764a1793..f62d859e63 100644
--- a/docs/_cpp_api/dir_cpp_api_include_trtorch.html
+++ b/docs/_cpp_api/dir_cpp_api_include_trtorch.html
@@ -294,6 +294,11 @@
+
+
+ trtorchc
+
+
diff --git a/docs/_cpp_api/enum_logging_8h_1a5f612ff2f783ff4fbe89d168f0d817d4.html b/docs/_cpp_api/enum_logging_8h_1a5f612ff2f783ff4fbe89d168f0d817d4.html
index 542e81728f..129e059169 100644
--- a/docs/_cpp_api/enum_logging_8h_1a5f612ff2f783ff4fbe89d168f0d817d4.html
+++ b/docs/_cpp_api/enum_logging_8h_1a5f612ff2f783ff4fbe89d168f0d817d4.html
@@ -301,6 +301,11 @@
+
+
+ trtorchc
+
+
diff --git a/docs/_cpp_api/file_cpp_api_include_trtorch_logging.h.html b/docs/_cpp_api/file_cpp_api_include_trtorch_logging.h.html
index 31065539ec..04ade4a022 100644
--- a/docs/_cpp_api/file_cpp_api_include_trtorch_logging.h.html
+++ b/docs/_cpp_api/file_cpp_api_include_trtorch_logging.h.html
@@ -294,6 +294,11 @@
+
+
+ trtorchc
+
+
@@ -666,6 +671,15 @@
+
+
+ trtorchc
+
+
diff --git a/docs/_cpp_api/file_cpp_api_include_trtorch_ptq.h.html b/docs/_cpp_api/file_cpp_api_include_trtorch_ptq.h.html
index 4fa2f78fe5..2523a931e6 100644
--- a/docs/_cpp_api/file_cpp_api_include_trtorch_ptq.h.html
+++ b/docs/_cpp_api/file_cpp_api_include_trtorch_ptq.h.html
@@ -294,6 +294,11 @@
+
+
+ trtorchc
+
+
@@ -657,6 +662,22 @@
+
+
+
+
+ trtorch/logging.h
+
+
+ (
+
+
+ File logging.h
+
+
+ )
+
+
diff --git a/docs/_cpp_api/file_cpp_api_include_trtorch_trtorch.h.html b/docs/_cpp_api/file_cpp_api_include_trtorch_trtorch.h.html
index 195f6874ff..aebe61a64f 100644
--- a/docs/_cpp_api/file_cpp_api_include_trtorch_trtorch.h.html
+++ b/docs/_cpp_api/file_cpp_api_include_trtorch_trtorch.h.html
@@ -294,6 +294,11 @@
+
+
+ trtorchc
+
+
diff --git a/docs/_cpp_api/file_view_hierarchy.html b/docs/_cpp_api/file_view_hierarchy.html
index 7032b84fb8..c78cea8757 100644
--- a/docs/_cpp_api/file_view_hierarchy.html
+++ b/docs/_cpp_api/file_view_hierarchy.html
@@ -294,6 +294,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/function_logging_8h_1a118d65b179defff7fff279eb9cd126cb.html b/docs/_cpp_api/function_logging_8h_1a118d65b179defff7fff279eb9cd126cb.html
index 0f8843cbdd..05e7dc9113 100644
--- a/docs/_cpp_api/function_logging_8h_1a118d65b179defff7fff279eb9cd126cb.html
+++ b/docs/_cpp_api/function_logging_8h_1a118d65b179defff7fff279eb9cd126cb.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/function_logging_8h_1a396a688110397538f8b3fb7dfdaf38bb.html b/docs/_cpp_api/function_logging_8h_1a396a688110397538f8b3fb7dfdaf38bb.html
index 114f9718ad..64a730c483 100644
--- a/docs/_cpp_api/function_logging_8h_1a396a688110397538f8b3fb7dfdaf38bb.html
+++ b/docs/_cpp_api/function_logging_8h_1a396a688110397538f8b3fb7dfdaf38bb.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/function_logging_8h_1a9b420280bfacc016d7e36a5704021949.html b/docs/_cpp_api/function_logging_8h_1a9b420280bfacc016d7e36a5704021949.html
index d380324c31..24130faded 100644
--- a/docs/_cpp_api/function_logging_8h_1a9b420280bfacc016d7e36a5704021949.html
+++ b/docs/_cpp_api/function_logging_8h_1a9b420280bfacc016d7e36a5704021949.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/function_logging_8h_1aa533955a2b908db9e5df5acdfa24715f.html b/docs/_cpp_api/function_logging_8h_1aa533955a2b908db9e5df5acdfa24715f.html
index 88d8902db6..1691a73dcf 100644
--- a/docs/_cpp_api/function_logging_8h_1aa533955a2b908db9e5df5acdfa24715f.html
+++ b/docs/_cpp_api/function_logging_8h_1aa533955a2b908db9e5df5acdfa24715f.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/function_logging_8h_1abc57d473f3af292551dee8b9c78373ad.html b/docs/_cpp_api/function_logging_8h_1abc57d473f3af292551dee8b9c78373ad.html
index 985f30ec38..f2768d2b4c 100644
--- a/docs/_cpp_api/function_logging_8h_1abc57d473f3af292551dee8b9c78373ad.html
+++ b/docs/_cpp_api/function_logging_8h_1abc57d473f3af292551dee8b9c78373ad.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/function_logging_8h_1adf5435f0dbb09c0d931a1b851847236b.html b/docs/_cpp_api/function_logging_8h_1adf5435f0dbb09c0d931a1b851847236b.html
index c5e37b48fc..f939b71ee0 100644
--- a/docs/_cpp_api/function_logging_8h_1adf5435f0dbb09c0d931a1b851847236b.html
+++ b/docs/_cpp_api/function_logging_8h_1adf5435f0dbb09c0d931a1b851847236b.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/function_logging_8h_1aef44b69c62af7cf2edc8875a9506641a.html b/docs/_cpp_api/function_logging_8h_1aef44b69c62af7cf2edc8875a9506641a.html
index 67a76c7ef1..0f0febae9b 100644
--- a/docs/_cpp_api/function_logging_8h_1aef44b69c62af7cf2edc8875a9506641a.html
+++ b/docs/_cpp_api/function_logging_8h_1aef44b69c62af7cf2edc8875a9506641a.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/function_trtorch_8h_1a2cf17d43ba9117b3b4d652744b4f0447.html b/docs/_cpp_api/function_trtorch_8h_1a2cf17d43ba9117b3b4d652744b4f0447.html
index 604ba98513..a732856cd5 100644
--- a/docs/_cpp_api/function_trtorch_8h_1a2cf17d43ba9117b3b4d652744b4f0447.html
+++ b/docs/_cpp_api/function_trtorch_8h_1a2cf17d43ba9117b3b4d652744b4f0447.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/function_trtorch_8h_1a4422781719d7befedb364cacd91c6247.html b/docs/_cpp_api/function_trtorch_8h_1a4422781719d7befedb364cacd91c6247.html
index 5cd9ff7479..7a4cd7291b 100644
--- a/docs/_cpp_api/function_trtorch_8h_1a4422781719d7befedb364cacd91c6247.html
+++ b/docs/_cpp_api/function_trtorch_8h_1a4422781719d7befedb364cacd91c6247.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/function_trtorch_8h_1a536bba54b70e44554099d23fa3d7e804.html b/docs/_cpp_api/function_trtorch_8h_1a536bba54b70e44554099d23fa3d7e804.html
index f415c4dbbb..4a3777ac7e 100644
--- a/docs/_cpp_api/function_trtorch_8h_1a536bba54b70e44554099d23fa3d7e804.html
+++ b/docs/_cpp_api/function_trtorch_8h_1a536bba54b70e44554099d23fa3d7e804.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/function_trtorch_8h_1a5f33b142bc2f3f2aaf462270b3ad7e31.html b/docs/_cpp_api/function_trtorch_8h_1a5f33b142bc2f3f2aaf462270b3ad7e31.html
index 98ee58c07b..b64edced5a 100644
--- a/docs/_cpp_api/function_trtorch_8h_1a5f33b142bc2f3f2aaf462270b3ad7e31.html
+++ b/docs/_cpp_api/function_trtorch_8h_1a5f33b142bc2f3f2aaf462270b3ad7e31.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/function_trtorch_8h_1a726f6e7091b6b7be45b5a4275b2ffb10.html b/docs/_cpp_api/function_trtorch_8h_1a726f6e7091b6b7be45b5a4275b2ffb10.html
index 11de089949..445ad9224a 100644
--- a/docs/_cpp_api/function_trtorch_8h_1a726f6e7091b6b7be45b5a4275b2ffb10.html
+++ b/docs/_cpp_api/function_trtorch_8h_1a726f6e7091b6b7be45b5a4275b2ffb10.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/function_trtorch_8h_1ab01696cfe08b6a5293c55935a9713c25.html b/docs/_cpp_api/function_trtorch_8h_1ab01696cfe08b6a5293c55935a9713c25.html
index a5b0c66cc7..a75b7ee703 100644
--- a/docs/_cpp_api/function_trtorch_8h_1ab01696cfe08b6a5293c55935a9713c25.html
+++ b/docs/_cpp_api/function_trtorch_8h_1ab01696cfe08b6a5293c55935a9713c25.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/function_trtorch_8h_1ae38897d1ca4438227c970029d0f76fb5.html b/docs/_cpp_api/function_trtorch_8h_1ae38897d1ca4438227c970029d0f76fb5.html
index e8d434f4f4..59793d3c32 100644
--- a/docs/_cpp_api/function_trtorch_8h_1ae38897d1ca4438227c970029d0f76fb5.html
+++ b/docs/_cpp_api/function_trtorch_8h_1ae38897d1ca4438227c970029d0f76fb5.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/namespace_trtorch.html b/docs/_cpp_api/namespace_trtorch.html
index bae0f7be6c..2a12108e0e 100644
--- a/docs/_cpp_api/namespace_trtorch.html
+++ b/docs/_cpp_api/namespace_trtorch.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/namespace_trtorch__logging.html b/docs/_cpp_api/namespace_trtorch__logging.html
index 4a625f5b9e..4582223721 100644
--- a/docs/_cpp_api/namespace_trtorch__logging.html
+++ b/docs/_cpp_api/namespace_trtorch__logging.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/namespace_trtorch__ptq.html b/docs/_cpp_api/namespace_trtorch__ptq.html
index a31b22746d..05e742a6c4 100644
--- a/docs/_cpp_api/namespace_trtorch__ptq.html
+++ b/docs/_cpp_api/namespace_trtorch__ptq.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/program_listing_file_cpp_api_include_trtorch_logging.h.html b/docs/_cpp_api/program_listing_file_cpp_api_include_trtorch_logging.h.html
index 92d75d8ea7..51289e027e 100644
--- a/docs/_cpp_api/program_listing_file_cpp_api_include_trtorch_logging.h.html
+++ b/docs/_cpp_api/program_listing_file_cpp_api_include_trtorch_logging.h.html
@@ -294,6 +294,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/program_listing_file_cpp_api_include_trtorch_macros.h.html b/docs/_cpp_api/program_listing_file_cpp_api_include_trtorch_macros.h.html
index 7873a19d40..1a259290fd 100644
--- a/docs/_cpp_api/program_listing_file_cpp_api_include_trtorch_macros.h.html
+++ b/docs/_cpp_api/program_listing_file_cpp_api_include_trtorch_macros.h.html
@@ -294,6 +294,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/program_listing_file_cpp_api_include_trtorch_ptq.h.html b/docs/_cpp_api/program_listing_file_cpp_api_include_trtorch_ptq.h.html
index 9eeffec98d..a6f6a84e77 100644
--- a/docs/_cpp_api/program_listing_file_cpp_api_include_trtorch_ptq.h.html
+++ b/docs/_cpp_api/program_listing_file_cpp_api_include_trtorch_ptq.h.html
@@ -294,6 +294,11 @@
+ trtorchc
@@ -512,6 +517,8 @@
#include <iostream>
#include <sstream>
+#include "trtorch/logging.h"
+
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace nvinfer1 {
class IInt8Calibrator;
@@ -519,9 +526,12 @@
}
namespace torch {
-namespace data {
-template <typename Example>
-class Iterator;
+class Tensor;
+}
+
+namespace trtorch {
+namespace ptq {
+bool get_batch_impl(void* bindings[], const char* names[], int nbBindings, torch::Tensor& data);
}
}
#endif //DOXYGEN_SHOULD_SKIP_THIS
@@ -535,7 +545,12 @@
using Batch = typename DataLoader::super::BatchType;
public:
Int8Calibrator(DataLoaderUniquePtr dataloader, const std::string& cache_file_path, bool use_cache)
- : dataloader_(dataloader.get()), it_(dataloader_->end()), cache_file_path_(cache_file_path), use_cache_(use_cache) {}
+ : dataloader_(dataloader.get()), cache_file_path_(cache_file_path), use_cache_(use_cache) {
+ for (auto batch : *dataloader_) {
+ batched_data_.push_back(batch.data);
+ }
+ it_ = batched_data_.begin();
+ }
int getBatchSize() const override {
// HACK: TRTorch only uses explicit batch sizing, INT8 Calibrator does not
@@ -546,26 +561,15 @@
}
bool getBatch(void* bindings[], const char* names[], int nbBindings) override {
- // HACK: doesnt seem like the first try in the initializer list works
- if (!it_created_) {
- it_ = dataloader_->begin();
- it_created_ = true;
- }
-
- if (it_ == dataloader_->end()) {
+ if (it_ != batched_data_.end()) {
+ auto status = get_batch_impl(bindings, names, nbBindings, *it_);
+ it_ = ++it_;
+ return status;
+ } else {
+ // Reset iterator in case the calibrator is going to be used again
+ it_ = batched_data_.begin();
return false;
}
-
- auto batch = *it_;
-
- for (int i = 0; i < nbBindings; i++) {
- auto data = batch.data;
- data = data.to(at::kCUDA).contiguous();
- bindings[i] = data.data_ptr();
- }
-
- it_ = ++it_;
- return true;
}
const void * readCalibrationCache ( size_t & length ) override {
@@ -573,18 +577,17 @@
std::stringstream ss;
ss << "Reading Calibration Cache from " << cache_file_path_;
logging::log(logging::Level::kINFO, ss.str());
+
cache_.clear();
- std::ifstream cache_file(cache_file_path_, std::ios::binary);
- cache_file >> std::noskipws;
- if (cache_file.good()) {
- std::copy(std::istream_iterator<char>(cache_file),
- std::istream_iterator<char>(),
- std::back_inserter(cache_));
- ss << "Cache read";
- logging::log(logging::Level::kDEBUG, ss.str());
+ std::ifstream input(cache_file_path_, std::ios::binary);
+ input >> std::noskipws;
+ if (input.good()) {
+ std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(),
+ std::back_inserter(cache_));
+ logging::log(logging::Level::kDEBUG, "Cache read");
}
- cache_size_ = cache_.size();
- return cache_size_ ? cache_.data() : nullptr;
+ length = cache_.size();
+ return length ? cache_.data() : nullptr;
}
return nullptr;
}
@@ -603,12 +606,13 @@
private:
DataLoader* dataloader_;
- torch::data::Iterator<Batch> it_;
const std::string& cache_file_path_;
size_t cache_size_ = 0;
bool use_cache_;
std::vector<char> cache_;
- bool it_created_ = false;
+ std::vector<torch::Tensor> batched_data_;
+ std::vector<torch::Tensor>::iterator it_;
+
};
template <typename Algorithm>
@@ -632,23 +636,17 @@
std::stringstream ss;
ss << "Reading Calibration Cache from " << cache_file_path_;
logging::log(logging::Level::kINFO, ss.str());
+
cache_.clear();
- std::ifstream cache_file;
- cache_file.open(cache_file_path_, std::ios::in | std::ios::binary);
- cache_file.unsetf(std::ios::skipws);
- cache_file.seekg(0, std::ios::beg);
- cache_.reserve(cache_file.tellg());
- cache_file.seekg(0, std::ios::beg);
- if (cache_file.good()) {
- std::cout << "Trying to read cache" << std::endl;
- std::copy(std::istreambuf_iterator<char>(cache_file),
- std::istreambuf_iterator<char>(),
- std::back_inserter(cache_));
- ss << "Cache read";
- logging::log(logging::Level::kDEBUG, ss.str());
+ std::ifstream input(cache_file_path_, std::ios::binary);
+ input >> std::noskipws;
+ if (input.good()) {
+ std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(),
+ std::back_inserter(cache_));
+ logging::log(logging::Level::kDEBUG, "Cache read");
}
- cache_size_ = cache_.size();
- return cache_size_ ? cache_.data() : nullptr;
+ length = cache_.size();
+ return length ? cache_.data() : nullptr;
}
diff --git a/docs/_cpp_api/program_listing_file_cpp_api_include_trtorch_trtorch.h.html b/docs/_cpp_api/program_listing_file_cpp_api_include_trtorch_trtorch.h.html
index 1184e3258d..c85c92f209 100644
--- a/docs/_cpp_api/program_listing_file_cpp_api_include_trtorch_trtorch.h.html
+++ b/docs/_cpp_api/program_listing_file_cpp_api_include_trtorch_trtorch.h.html
@@ -294,6 +294,11 @@
+ trtorchc
@@ -568,7 +573,9 @@ operator Value() const { return value; }
explicit operator bool() = delete;
constexpr bool operator==(DataType other) const { return value == other.value; }
+ constexpr bool operator==(DataType::Value other) const { return value == other; }
constexpr bool operator!=(DataType other) const { return value != other.value; }
+ constexpr bool operator!=(DataType::Value other) const { return value != other; }
private:
Value value;
};
diff --git a/docs/_cpp_api/structtrtorch_1_1ExtraInfo.html b/docs/_cpp_api/structtrtorch_1_1ExtraInfo.html
index e5719a16e4..e98a7cedbd 100644
--- a/docs/_cpp_api/structtrtorch_1_1ExtraInfo.html
+++ b/docs/_cpp_api/structtrtorch_1_1ExtraInfo.html
@@ -301,6 +301,11 @@
+ trtorchc
@@ -1707,6 +1718,99 @@
+
+ Comparison operator for DataType. Returns true/false. Parameters: other
+
+ Comparison operator for DataType. Returns true/false. Parameters: other
+
diff --git a/docs/_cpp_api/structtrtorch_1_1ExtraInfo_1_1InputRange.html b/docs/_cpp_api/structtrtorch_1_1ExtraInfo_1_1InputRange.html
index 8b58f44947..569b00db96 100644
--- a/docs/_cpp_api/structtrtorch_1_1ExtraInfo_1_1InputRange.html
+++ b/docs/_cpp_api/structtrtorch_1_1ExtraInfo_1_1InputRange.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/trtorch_cpp.html b/docs/_cpp_api/trtorch_cpp.html
index f23f54243e..44e285ede8 100644
--- a/docs/_cpp_api/trtorch_cpp.html
+++ b/docs/_cpp_api/trtorch_cpp.html
@@ -296,6 +296,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/unabridged_api.html b/docs/_cpp_api/unabridged_api.html
index f862c086d1..96a93493a0 100644
--- a/docs/_cpp_api/unabridged_api.html
+++ b/docs/_cpp_api/unabridged_api.html
@@ -294,6 +294,11 @@
+ trtorchc
diff --git a/docs/_cpp_api/unabridged_orphan.html b/docs/_cpp_api/unabridged_orphan.html
index cc96ebf58c..554cd70624 100644
--- a/docs/_cpp_api/unabridged_orphan.html
+++ b/docs/_cpp_api/unabridged_orphan.html
@@ -294,6 +294,11 @@
+ trtorchc
diff --git a/docs/_sources/_cpp_api/file_cpp_api_include_trtorch_logging.h.rst.txt b/docs/_sources/_cpp_api/file_cpp_api_include_trtorch_logging.h.rst.txt
index bec2619937..1ca71e13ce 100644
--- a/docs/_sources/_cpp_api/file_cpp_api_include_trtorch_logging.h.rst.txt
+++ b/docs/_sources/_cpp_api/file_cpp_api_include_trtorch_logging.h.rst.txt
@@ -39,6 +39,8 @@ Included By
-----------
+- :ref:`file_cpp_api_include_trtorch_ptq.h`
+
- :ref:`file_cpp_api_include_trtorch_trtorch.h`
diff --git a/docs/_sources/_cpp_api/file_cpp_api_include_trtorch_ptq.h.rst.txt b/docs/_sources/_cpp_api/file_cpp_api_include_trtorch_ptq.h.rst.txt
index ff8e4dacc1..a5f33139f7 100644
--- a/docs/_sources/_cpp_api/file_cpp_api_include_trtorch_ptq.h.rst.txt
+++ b/docs/_sources/_cpp_api/file_cpp_api_include_trtorch_ptq.h.rst.txt
@@ -37,6 +37,8 @@ Includes
- ``string``
+- ``trtorch/logging.h`` (:ref:`file_cpp_api_include_trtorch_logging.h`)
+
- ``vector``
diff --git a/docs/_sources/_cpp_api/program_listing_file_cpp_api_include_trtorch_ptq.h.rst.txt b/docs/_sources/_cpp_api/program_listing_file_cpp_api_include_trtorch_ptq.h.rst.txt
index 93cefc6c66..6d02e502f4 100644
--- a/docs/_sources/_cpp_api/program_listing_file_cpp_api_include_trtorch_ptq.h.rst.txt
+++ b/docs/_sources/_cpp_api/program_listing_file_cpp_api_include_trtorch_ptq.h.rst.txt
@@ -18,6 +18,8 @@ Program Listing for File ptq.h
#include <iostream>
#include <sstream>
+ #include "trtorch/logging.h"
+
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace nvinfer1 {
class IInt8Calibrator;
@@ -25,9 +27,12 @@ Program Listing for File ptq.h
}
namespace torch {
- namespace data {
- template <typename Example>
- class Iterator;
+ class Tensor;
+ }
+
+ namespace trtorch {
+ namespace ptq {
+ bool get_batch_impl(void* bindings[], const char* names[], int nbBindings, torch::Tensor& data);
}
}
#endif //DOXYGEN_SHOULD_SKIP_THIS
@@ -41,7 +46,12 @@ Program Listing for File ptq.h
using Batch = typename DataLoader::super::BatchType;
public:
Int8Calibrator(DataLoaderUniquePtr dataloader, const std::string& cache_file_path, bool use_cache)
- : dataloader_(dataloader.get()), it_(dataloader_->end()), cache_file_path_(cache_file_path), use_cache_(use_cache) {}
+ : dataloader_(dataloader.get()), cache_file_path_(cache_file_path), use_cache_(use_cache) {
+ for (auto batch : *dataloader_) {
+ batched_data_.push_back(batch.data);
+ }
+ it_ = batched_data_.begin();
+ }
int getBatchSize() const override {
// HACK: TRTorch only uses explicit batch sizing, INT8 Calibrator does not
@@ -52,26 +62,15 @@ Program Listing for File ptq.h
}
bool getBatch(void* bindings[], const char* names[], int nbBindings) override {
- // HACK: doesnt seem like the first try in the initializer list works
- if (! it_created_) {
- it_ = dataloader_->begin();
- it_created_ = true;
- }
-
- if (it_ == dataloader_->end()) {
+ if (it_ != batched_data_.end()) {
+ auto status = get_batch_impl(bindings, names, nbBindings, *it_);
+ it_ = ++it_;
+ return status;
+ } else {
+ // Reset iterator in case the calibrator is going to be used again
+ it_ = batched_data_.begin();
return false;
}
-
- auto batch = *it_;
-
- for (int i = 0; i < nbBindings; i++) {
- auto data = batch.data;
- data = data.to(at::kCUDA).contiguous();
- bindings[i] = data.data_ptr();
- }
-
- it_ = ++it_;
- return true;
}
const void* readCalibrationCache(size_t& length) override {
@@ -79,18 +78,17 @@ Program Listing for File ptq.h
std::stringstream ss;
ss << "Reading Calibration Cache from " << cache_file_path_;
logging::log(logging::Level::kINFO, ss.str());
+
cache_.clear();
- std::ifstream cache_file(cache_file_path_, std::ios::binary);
- cache_file >> std::noskipws;
- if (cache_file.good()) {
- std::copy(std::istream_iterator<char>(cache_file),
- std::istream_iterator<char>(),
- std::back_inserter(cache_));
- ss << "Cache read";
- logging::log(logging::Level::kDEBUG, ss.str());
+ std::ifstream input(cache_file_path_, std::ios::binary);
+ input >> std::noskipws;
+ if (input.good()) {
+ std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(),
+ std::back_inserter(cache_));
+ logging::log(logging::Level::kDEBUG, "Cache read");
}
- cache_size_ = cache_.size();
- return cache_size_ ? cache_.data() : nullptr;
+ length = cache_.size();
+ return length ? cache_.data() : nullptr;
}
return nullptr;
}
@@ -109,12 +107,13 @@ Program Listing for File ptq.h
private:
DataLoader* dataloader_;
- torch::data::Iterator<Batch> it_;
const std::string& cache_file_path_;
size_t cache_size_ = 0;
bool use_cache_;
std::vector<char> cache_;
- bool it_created_ = false;
+ std::vector<torch::Tensor> batched_data_;
+ std::vector<torch::Tensor>::iterator it_;
+
};
template <typename Algorithm>
@@ -138,23 +137,17 @@ Program Listing for File ptq.h
std::stringstream ss;
ss << "Reading Calibration Cache from " << cache_file_path_;
logging::log(logging::Level::kINFO, ss.str());
+
cache_.clear();
- std::ifstream cache_file;
- cache_file.open(cache_file_path_, std::ios::in | std::ios::binary);
- cache_file.unsetf(std::ios::skipws);
- cache_file.seekg(0, std::ios::beg);
- cache_.reserve(cache_file.tellg());
- cache_file.seekg(0, std::ios::beg);
- if (cache_file.good()) {
- std::cout << "Trying to read cache" << std::endl;
- std::copy(std::istreambuf_iterator<char>(cache_file),
- std::istreambuf_iterator<char>(),
- std::back_inserter(cache_));
- ss << "Cache read";
- logging::log(logging::Level::kDEBUG, ss.str());
+ std::ifstream input(cache_file_path_, std::ios::binary);
+ input >> std::noskipws;
+ if (input.good()) {
+ std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(),
+ std::back_inserter(cache_));
+ logging::log(logging::Level::kDEBUG, "Cache read");
}
- cache_size_ = cache_.size();
- return cache_size_ ? cache_.data() : nullptr;
+ length = cache_.size();
+ return length ? cache_.data() : nullptr;
}
diff --git a/docs/_sources/_cpp_api/program_listing_file_cpp_api_include_trtorch_trtorch.h.rst.txt b/docs/_sources/_cpp_api/program_listing_file_cpp_api_include_trtorch_trtorch.h.rst.txt
index 14d26d87b9..fa356d80b0 100644
--- a/docs/_sources/_cpp_api/program_listing_file_cpp_api_include_trtorch_trtorch.h.rst.txt
+++ b/docs/_sources/_cpp_api/program_listing_file_cpp_api_include_trtorch_trtorch.h.rst.txt
@@ -74,7 +74,9 @@ Program Listing for File trtorch.h
operator Value() const { return value; }
explicit operator bool() = delete;
constexpr bool operator==(DataType other) const { return value == other.value; }
+ constexpr bool operator==(DataType::Value other) const { return value == other; }
constexpr bool operator!=(DataType other) const { return value != other.value; }
+ constexpr bool operator!=(DataType::Value other) const { return value != other; }
private:
Value value;
};
diff --git a/docs/_sources/index.rst.txt b/docs/_sources/index.rst.txt
index 5255135f58..45a1610b49 100644
--- a/docs/_sources/index.rst.txt
+++ b/docs/_sources/index.rst.txt
@@ -23,15 +23,18 @@ Getting Started
* :ref:`installation`
* :ref:`getting_started`
* :ref:`ptq`
+* :ref:`trtorchc`
+
.. toctree::
:caption: Getting Started
- :maxdepth: 2
+ :maxdepth: 1
:hidden:
tutorials/installation
tutorials/getting_started
tutorials/ptq
+ tutorials/trtorchc
Contributor Documentation
--------------------------------
diff --git a/docs/_sources/tutorials/getting_started.rst.txt b/docs/_sources/tutorials/getting_started.rst.txt
index 0d133a7eab..45c08b8637 100644
--- a/docs/_sources/tutorials/getting_started.rst.txt
+++ b/docs/_sources/tutorials/getting_started.rst.txt
@@ -130,7 +130,8 @@ To compile your TorchScript module with TRTorch, all you need to do is provide t
to TRTorch and you will be returned an optimized TorchScript module to run or add into another PyTorch module. The
only required setting is the input size or input range, which is defined as a list of either list types (``lists``, ``tuples``
or PyTorch ``size`` objects) or dictionaries of minimum, optimal and maximum sizes. You can also specify settings such as
-operating precision for the engine or target device.
+operating precision for the engine or target device. After compilation you can save the module just like any other module
+to load later in a deployment application. In order to load a TensorRT/TorchScript module, make sure you first import ``trtorch``.
.. code-block:: python
@@ -152,6 +153,17 @@ operating precision for the engine or target device.
input_data = input_data.half()
result = trt_ts_module(input_data)
+ torch.jit.save(trt_ts_module, "trt_ts_module.ts")
+
+.. code-block:: python
+
+ # Deployment application
+ import torch
+ import trtorch
+
+ trt_ts_module = torch.jit.load("trt_ts_module.ts")
+ input_data = input_data.half()
+ result = trt_ts_module(input_data)
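+
+The ``import trtorch`` in the deployment application matters: importing the package is what registers
+the TensorRT engine executor with TorchScript, so without it ``torch.jit.load()`` cannot deserialize
+and run the compiled module.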
.. _ts_in_cc:
@@ -251,7 +263,35 @@ We can also set settings like operating precision to run in FP16.
auto trt_mod = trtorch::CompileGraph(mod, info);
auto out = trt_mod.forward({in});
-And now we are running the module in FP16 precision.
+And now we are running the module in FP16 precision. You can then save the module to load later.
+
+.. code-block:: c++
+
+ trt_mod.save("")
+
+TRTorch-compiled TorchScript modules are loaded in the same way as normal TorchScript modules. Make sure your deployment application is linked against ``libtrtorch.so``.
+
+.. code-block:: c++
+
+ #include <iostream>
+ #include "torch/script.h"
+ #include "trtorch/trtorch.h"
+
+ int main(int argc, const char* argv[]) {
+ torch::jit::Module module;
+ try {
+ // Deserialize the ScriptModule from a file using torch::jit::load().
+ module = torch::jit::load("");
+ }
+ catch (const c10::Error& e) {
+ std::cerr << "error loading the model\n";
+ return -1;
+ }
+
+ torch::Tensor in = torch::randn({1, 1, 32, 32}, torch::kCUDA);
+ auto out = module.forward({in});
+
+ std::cout << "ok\n";
+ }
If you want to save the engine produced by TRTorch to use in a TensorRT application you can use the ``ConvertGraphToTRTEngine`` API.
diff --git a/docs/_sources/tutorials/trtorchc.rst.txt b/docs/_sources/tutorials/trtorchc.rst.txt
new file mode 100644
index 0000000000..5561ee86ed
--- /dev/null
+++ b/docs/_sources/tutorials/trtorchc.rst.txt
@@ -0,0 +1,91 @@
+.. _trtorchc:
+
+trtorchc
+=================================
+
+``trtorchc`` is a CLI application for using the TRTorch compiler. It serves as an easy way to compile a
+TorchScript Module with TRTorch from the command-line to quickly check support or as part of
+a deployment pipeline. All basic features of the compiler are supported including post training
+quantization (though you must already have a calibration cache file to use the PTQ feature). The compiler can
+output one of two formats: either a TorchScript program with the TensorRT engine embedded, or
+the TensorRT engine itself as a PLAN file.
+
+All that is required to run a compiled program is, in C++, linking against ``libtrtorch.so``,
+or, in Python, importing the ``trtorch`` package. All other aspects of using compiled modules are identical
+to standard TorchScript: load with ``torch.jit.load()`` and run like you would run any other module.
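+
+For example, a minimal Python sketch of loading and running a compiled module (the module path and
+the input shape here are hypothetical, and a CUDA device is assumed):
+
+.. code-block:: python
+
+    import torch
+    import trtorch # registers the TensorRT engine executor with TorchScript
+
+    trt_module = torch.jit.load("compiled_module.ts")
+    input_data = torch.randn((1, 3, 224, 224), device="cuda")
+    result = trt_module(input_data)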
+
+.. code-block:: txt
+
+ trtorchc [input_file_path] [output_file_path]
+ [input_shapes...] {OPTIONS}
+
+ TRTorch is a compiler for TorchScript, it will compile and optimize
+ TorchScript programs to run on NVIDIA GPUs using TensorRT
+
+ OPTIONS:
+
+ -h, --help Display this help menu
+ Verbiosity of the compiler
+ -v, --verbose Dumps debugging information about the
+ compilation process onto the console
+ -w, --warnings Disables warnings generated during
+ compilation onto the console (warnings
+ are on by default)
+ --info Dumps info messages generated during
+ compilation onto the console
+ --build-debuggable-engine Creates a debuggable engine
+ --use-strict-types Restrict operating type to only use set
+ default operation precision
+ (op_precision)
+ --allow-gpu-fallback (Only used when targeting DLA
+ (device-type)) Lets engine run layers on
+ GPU if they are not supported on DLA
+ -p[precision],
+ --default-op-precision=[precision]
+ Default operating precision for the
+ engine (Int8 requires a
+ calibration-cache argument) [ float |
+ float32 | f32 | half | float16 | f16 |
+ int8 | i8 ] (default: float)
+ -d[type], --device-type=[type] The type of device the engine should be
+ built for [ gpu | dla ] (default: gpu)
+ --engine-capability=[capability] The type of device the engine should be
+ built for [ default | safe_gpu |
+ safe_dla ]
+ --calibration-cache-file=[file_path]
+ Path to calibration cache file to use
+ for post training quantization
+ --num-min-timing-iter=[num_iters] Number of minimization timing iterations
+ used to select kernels
+ --num-avg-timing-iters=[num_iters]
+ Number of averaging timing iterations
+ used to select kernels
+ --workspace-size=[workspace_size] Maximum size of workspace given to
+ TensorRT
+ --max-batch-size=[max_batch_size] Maximum batch size (must be >= 1 to be
+ set, 0 means not set)
+ -t[threshold],
+ --threshold=[threshold] Maximum acceptable numerical deviation
+ from standard torchscript output
+ (default 2e-5)
+ --save-engine Instead of compiling a full a
+ TorchScript program, save the created
+ engine to the path specified as the
+ output path
+ input_file_path Path to input TorchScript file
+ output_file_path Path for compiled TorchScript (or
+ TensorRT engine) file
+ input_shapes... Sizes for inputs to engine, can either
+ be a single size or a range defined by
+ Min, Optimal, Max sizes, e.g.
+ "(N,..,C,H,W)"
+ "[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]"
+ "--" can be used to terminate flag options and force all following
+ arguments to be treated as positional options
+
+
+e.g.
+
+.. code-block:: txt
+
+ trtorchc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]" -p f16
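+
+Since this example compiles with ``-p f16``, inputs to the saved ``ssd_trt.ts`` module should be
+converted to half precision (e.g. with ``input_data.half()``) before the forward call, just as with
+any other module compiled at FP16.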
diff --git a/docs/contributors/conversion.html b/docs/contributors/conversion.html
index 078214b750..3249c81d4f 100644
--- a/docs/contributors/conversion.html
+++ b/docs/contributors/conversion.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/contributors/execution.html b/docs/contributors/execution.html
index 9b77950c3b..b80c618615 100644
--- a/docs/contributors/execution.html
+++ b/docs/contributors/execution.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/contributors/lowering.html b/docs/contributors/lowering.html
index 9d3728b25d..ee974b61a2 100644
--- a/docs/contributors/lowering.html
+++ b/docs/contributors/lowering.html
@@ -301,6 +301,11 @@
+ trtorchc
diff --git a/docs/contributors/phases.html b/docs/contributors/phases.html
index f5e355f36b..1583f5200b 100644
--- a/docs/contributors/phases.html
+++ b/docs/contributors/phases.html
@@ -294,6 +294,11 @@
+ trtorchc
diff --git a/docs/contributors/system_overview.html b/docs/contributors/system_overview.html
index 53e8c4bbdd..05f7f760fc 100644
--- a/docs/contributors/system_overview.html
+++ b/docs/contributors/system_overview.html
@@ -296,6 +296,11 @@
+ trtorchc
@@ -695,7 +700,7 @@