Skip to content

Commit 51a2043

Browse files
authored
Merge pull request #194 from NVIDIA/to_backend_api
Implementation of the PyTorch Backend API
2 parents 2809f2b + d150930 commit 51a2043

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

52 files changed

+962
-444
lines changed

.bazelrc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ build --cxxopt='-std=c++14'
2929
build:python --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"
3030
build:python --linkopt="-D_GLIBCXX_USE_CXX11_ABI=0"
3131
build:python --define=abi=pre_cxx11_abi
32+
build:python --define=target_lang=python
3233

3334
build:pre_cxx11_abi --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"
3435
build:pre_cxx11_abi --linkopt="-D_GLIBCXX_USE_CXX11_ABI=0"

.github/pr-labels.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
"component: evaluators":
1717
- core/conversion/evaluators/**/*
1818

19-
"component: execution":
20-
- core/execution/**/*
19+
"component: runtime":
20+
- core/runtime/**/*
2121

2222
"component: lowering":
2323
- core/lowering/**/*
@@ -32,4 +32,4 @@
3232
"documentation":
3333
- docs/**/*
3434
- docsrc/**/*
35-
35+

BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ pkg_tar(
1919
"//core/conversion/tensorcontainer:include",
2020
"//core/conversion/evaluators:include",
2121
"//core/conversion/converters/impl/plugins:include",
22-
"//core/execution:include",
22+
"//core/runtime:include",
2323
"//core/lowering:include",
2424
"//core/lowering/passes:include",
2525
"//core/util:include",

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ More Information / System Architecture:
1818
#include "trtorch/trtorch.h"
1919

2020
...
21-
auto compile_settings = trtorch::ExtraInfo(dims);
21+
auto compile_settings = trtorch::CompileSpec(dims);
2222
// FP16 execution
2323
compile_settings.op_precision = torch::kFloat;
2424
// Compile module
@@ -54,7 +54,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")
5454
```
5555

5656
> Notes on running in lower precisions:
57-
> - Set precision with extra_info.op_precision
57+
> - Set precision with compile_spec.op_precision
5858
> - The module should be left in FP32 before compilation (FP16 can support half tensor models)
5959
> - In FP16 only input tensors should be converted to FP16, other precisions use FP32
6060

core/BUILD

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ config_setting(
77
}
88
)
99

10+
config_setting(
11+
name = "python_core",
12+
values = {
13+
"define": "target_lang=python"
14+
}
15+
)
16+
1017
cc_library(
1118
name = "core",
1219
hdrs = [
@@ -17,7 +24,7 @@ cc_library(
1724
],
1825
deps = [
1926
"//core/conversion",
20-
"//core/execution",
27+
"//core/runtime",
2128
"//core/lowering",
2229
"//core/util/logging",
2330
"@tensorrt//:nvinfer"
@@ -28,11 +35,13 @@ cc_library(
2835
alwayslink=True,
2936
)
3037

31-
3238
load("@rules_pkg//:pkg.bzl", "pkg_tar")
3339

3440
pkg_tar(
3541
name = "include",
3642
package_dir = "core/",
37-
srcs = ["compiler.h"],
43+
srcs = [
44+
"backend.h",
45+
"compiler.h",
46+
],
3847
)

core/compiler.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
#include "core/lowering/lowering.h"
2222
#include "core/conversion/conversion.h"
23-
#include "core/execution/execution.h"
23+
#include "core/runtime/runtime.h"
2424

2525
namespace trtorch {
2626
namespace core {
@@ -42,15 +42,15 @@ c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::str
4242

4343

4444
void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit::Graph>& g, std::string& serialized_engine) {
45-
auto engine_ptr = c10::make_intrusive<execution::TRTEngine>(mod._ivalue()->name(), serialized_engine);
45+
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), serialized_engine);
4646
// Get required metadata about the engine out
4747
auto num_io = engine_ptr->num_io;
4848
auto name = engine_ptr->name;
4949

5050
// Add the engine as an attribute of the module, this will let the engine be serialized and deserialized
5151
mod.register_attribute(
5252
name,
53-
c10::getCustomClassType<c10::intrusive_ptr<execution::TRTEngine>>(),
53+
c10::getCustomClassType<c10::intrusive_ptr<runtime::TRTEngine>>(),
5454
c10::IValue(std::move(engine_ptr)),
5555
false
5656
);
@@ -125,7 +125,7 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod,
125125

126126
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
127127
std::string method_name,
128-
ExtraInfo cfg) {
128+
CompileSpec cfg) {
129129

130130
// Go through Lowering to simplify graph and extract weight parameters
131131
auto graph_and_parameters = lowering::Lower(mod, method_name);
@@ -137,12 +137,12 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
137137

138138
LOG_INFO(*g << "(CompileGraph)\n");
139139

140-
auto engine = ConvertBlockToEngine(g->block(), convert_cfg, named_params);
140+
auto engine = conversion::ConvertBlockToEngine(g->block(), convert_cfg, named_params);
141141
return std::move(engine);
142142
}
143143

144144
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
145-
ExtraInfo cfg) {
145+
CompileSpec cfg) {
146146
// TODO: Should be doing a functional transform but need PR #31978
147147
// [jit] More robust mangling
148148
//torch::jit::script::Module new_mod = mod.clone();

core/compiler.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,18 @@
77
namespace trtorch {
88
namespace core {
99

10-
struct ExtraInfo {
11-
ExtraInfo(std::vector<conversion::InputRange> input_ranges)
10+
struct CompileSpec {
11+
CompileSpec(std::vector<conversion::InputRange> input_ranges)
1212
: convert_info(std::move(input_ranges)) {}
1313
conversion::ConversionInfo convert_info;
1414
};
1515

1616
bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);
1717

1818
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
19-
std::string method_name, ExtraInfo cfg);
19+
std::string method_name, CompileSpec cfg);
2020

21-
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo cfg);
21+
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);
2222

2323
} // namespace core
2424
} // namespace trtorch

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
5555
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
5656
}
5757
input_type = nvinfer1::DataType::kFLOAT;
58-
TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the ExtraInfo struct with your calibrator");
58+
TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec struct with your calibrator");
5959
cfg->setInt8Calibrator(settings.calibrator);
6060
break;
6161
case nvinfer1::DataType::kFLOAT:

core/execution/BUILD renamed to core/runtime/BUILD

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ config_setting(
88
)
99

1010
cc_library(
11-
name = "execution",
11+
name = "runtime",
1212
hdrs = [
13-
"execution.h",
13+
"runtime.h",
1414
],
1515
srcs = [
1616
"TRTEngine.cpp",
@@ -30,6 +30,6 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar")
3030

3131
pkg_tar(
3232
name = "include",
33-
package_dir = "core/execution/",
34-
srcs = ["execution.h"],
33+
package_dir = "core/runtime/",
34+
srcs = ["runtime.h"],
3535
)

core/execution/TRTEngine.cpp renamed to core/runtime/TRTEngine.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
#include "torch/csrc/jit/frontend/function_schema_parser.h"
55

66
#include "core/util/prelude.h"
7-
#include "core/execution/execution.h"
7+
#include "core/runtime/runtime.h"
88

99
namespace trtorch {
1010
namespace core {
11-
namespace execution {
11+
namespace runtime {
1212

1313
std::string slugify(std::string s) {
1414
std::replace(s.begin(), s.end(), '.', '_');
@@ -81,6 +81,7 @@ TRTEngine::~TRTEngine() {
8181
// return c10::List<at::Tensor>(output_vec);
8282
// }
8383

84+
namespace {
8485
static auto TRTORCH_UNUSED TRTEngineTSRegistrtion = torch::class_<TRTEngine>("tensorrt", "Engine")
8586
.def(torch::init<std::string>())
8687
// TODO: .def("__call__", &TRTEngine::Run)
@@ -94,7 +95,7 @@ static auto TRTORCH_UNUSED TRTEngineTSRegistrtion = torch::class_<TRTEngine>("te
9495
return c10::make_intrusive<TRTEngine>(std::move(seralized_engine));
9596
}
9697
);
97-
98-
} // namespace execution
98+
} // namespace
99+
} // namespace runtime
99100
} // namespace core
100101
} // namespace trtorch

core/execution/register_trt_op.cpp renamed to core/runtime/register_trt_op.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
#include "torch/csrc/jit/runtime/custom_operator.h"
55

66
#include "core/util/prelude.h"
7-
#include "core/execution/execution.h"
7+
#include "core/runtime/runtime.h"
88

99
namespace trtorch {
1010
namespace core {
11-
namespace execution {
11+
namespace runtime {
1212

1313
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
1414
LOG_DEBUG("Attempting to run engine (ID: " << compiled_engine->name << ")");
@@ -30,7 +30,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
3030
gpu_handles.push_back(contig_inputs.back().data_ptr());
3131
}
3232

33-
TRTORCH_CHECK(compiled_engine->exec_ctx->allInputDimensionsSpecified(), "Not enough inputs provided (execution.RunCudaEngine)");
33+
TRTORCH_CHECK(compiled_engine->exec_ctx->allInputDimensionsSpecified(), "Not enough inputs provided (runtime.RunCudaEngine)");
3434

3535
std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
3636
for (size_t o = inputs.size(); o < (compiled_engine->num_io.first + compiled_engine->num_io.second); o++) {
@@ -53,6 +53,6 @@ TORCH_LIBRARY(tensorrt, m) {
5353
m.def("execute_engine", execute_engine);
5454
}
5555

56-
} // namespace execution
56+
} // namespace runtime
5757
} // namespace core
5858
} // namespace trtorch

core/execution/execution.h renamed to core/runtime/runtime.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
namespace trtorch {
1010
namespace core {
11-
namespace execution {
11+
namespace runtime {
1212

1313
using EngineID = int64_t;
1414

@@ -35,6 +35,6 @@ struct TRTEngine : torch::CustomClassHolder {
3535

3636
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);
3737

38-
} // namespace execution
38+
} // namespace runtime
3939
} // namespace core
4040
} // namespace trtorch

cpp/api/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ cc_library(
99
"include/trtorch/ptq.h"
1010
],
1111
srcs = [
12-
"src/extra_info.cpp",
12+
"src/compile_spec.cpp",
1313
"src/logging.cpp",
1414
"src/trtorch.cpp",
1515
"src/ptq.cpp"

cpp/api/README.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ namespace trtorch {
3131
* Settings data structure for TRTorch compilation
3232
*
3333
*/
34-
struct TRTORCH_API ExtraInfo {
34+
struct TRTORCH_API CompileSpec {
3535
/**
3636
* @brief A struct to hold an input range (used by TensorRT Optimization profile)
3737
*
@@ -132,10 +132,10 @@ struct TRTORCH_API ExtraInfo {
132132
kSAFE_DLA,
133133
};
134134

135-
ExtraInfo(std::vector<InputRange> input_ranges)
135+
CompileSpec(std::vector<InputRange> input_ranges)
136136
: input_ranges(std::move(input_ranges)) {}
137-
ExtraInfo(std::vector<std::vector<int64_t>> fixed_sizes);
138-
ExtraInfo(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);
137+
CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes);
138+
CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);
139139

140140
// Defaults should reflect TensorRT defaults for BuilderConfig
141141

@@ -236,27 +236,27 @@ TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::script::Module& mo
236236
* @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT
237237
*
238238
* @param module: torch::jit::script::Module - Existing TorchScript module
239-
* @param info: trtorch::ExtraInfo - Compilation settings
239+
* @param info: trtorch::CompileSpec - Compilation settings
240240
*
241241
* Takes a existing TorchScript module and a set of settings to configure the compiler
242242
* and will convert methods to JIT Graphs which call equivalent TensorRT engines
243243
*
244244
* Converts specifically the forward method of a TorchScript Module
245245
*/
246-
TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo info);
246+
TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec info);
247247

248248
/**
249249
* @brief Compile a TorchScript method for NVIDIA GPUs using TensorRT
250250
*
251251
* @param module: torch::jit::script::Module - Existing TorchScript module
252252
* @param method_name: std::string - Name of method to compile
253-
* @param info: trtorch::ExtraInfo - Compilation settings
253+
* @param info: trtorch::CompileSpec - Compilation settings
254254
*
255255
* Takes a existing TorchScript module and a set of settings to configure the compiler
256256
* and will convert selected method to a serialized TensorRT engine which can be run with
257257
* TensorRT
258258
*/
259-
TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, std::string method_name, ExtraInfo info);
259+
TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, std::string method_name, CompileSpec info);
260260

261261
namespace ptq {
262262
/**

cpp/api/include/trtorch/ptq.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ class Int8Calibrator : Algorithm {
145145
/**
146146
* @brief operator to cast to nvinfer1::IInt8Calibrator*
147147
*
148-
* Convience function to convert to a IInt8Calibrator* to easily be assigned to the ptq_calibrator field in ExtraInfo
148+
* Convience function to convert to a IInt8Calibrator* to easily be assigned to the ptq_calibrator field in CompileSpec
149149
*
150150
* @return nvinfer1::IInt8Calibrator*
151151
*/
@@ -259,7 +259,7 @@ class Int8CacheCalibrator : Algorithm {
259259
/**
260260
* @brief operator to cast to nvinfer1::IInt8Calibrator*
261261
*
262-
* Convience function to convert to a IInt8Calibrator* to easily be assigned to the ptq_calibrator field in ExtraInfo
262+
* Convience function to convert to a IInt8Calibrator* to easily be assigned to the ptq_calibrator field in CompileSpec
263263
*
264264
* @return nvinfer1::IInt8Calibrator*
265265
*/

0 commit comments

Comments
 (0)