
Commit fb1a299

refactor(//core/partitioning): Reorganizing partitioning deps

Reorganizing the partitioning dependencies so that there is a clearer relationship between major compiler modules.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

1 parent 6e96289 · commit fb1a299

22 files changed: +249 −172 lines

core/compiler.cpp (4 additions, 4 deletions)

@@ -214,7 +214,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
 
   // segment the graph and convert segmented TensorRT block
   auto segmented_blocks =
-      partitioning::Partition(g, convert_cfg.input_ranges, convert_cfg.engine_settings.torch_fallback);
+      partitioning::Partition(g, convert_cfg.input_ranges, cfg.partition_info);
   if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
     return mod;
   }
@@ -223,9 +223,9 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
   std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
   for (auto& seg_block : segmented_blocks) {
     if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
-      std::vector<conversion::InputRange> input_ranges;
+      std::vector<ir::InputRange> input_ranges;
       for (auto& shape : seg_block.in_shape()) {
-        input_ranges.push_back(conversion::InputRange(util::toVec(shape)));
+        input_ranges.push_back(ir::InputRange(util::toVec(shape)));
       }
       // update the input ranges for each segments
       convert_cfg.input_ranges = input_ranges;
@@ -258,7 +258,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
 
 torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
   // TODO: not sure how to deal with duplicated code here, so just cut out a branch temporally
-  if (cfg.convert_info.engine_settings.torch_fallback.enabled) {
+  if (cfg.partition_info.enabled) {
     return CompileGraphWithFallback(mod, cfg);
   }
   // TODO: Should be doing a functional transform but need PR #31978

core/compiler.h (4 additions, 1 deletion)

@@ -3,14 +3,17 @@
 #include <cuda_runtime.h>
 #include <vector>
 #include "core/conversion/conversion.h"
+#include "core/ir/ir.h"
+#include "core/partitioning/partitioning.h"
 #include "torch/csrc/jit/api/module.h"
 
 namespace trtorch {
 namespace core {
 
 struct CompileSpec {
-  CompileSpec(std::vector<conversion::InputRange> input_ranges) : convert_info(std::move(input_ranges)) {}
+  CompileSpec(std::vector<ir::InputRange> input_ranges) : convert_info(std::move(input_ranges)) {}
   conversion::ConversionInfo convert_info;
+  partitioning::PartitionInfo partition_info;
 };
 
 bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);
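With this change, fallback is configured on CompileSpec itself rather than through conversion::BuilderSettings. A minimal sketch of building a spec against the new layout (the MakeSpec helper, shapes, and fallback values are illustrative, not part of this commit):

```cpp
#include "core/compiler.h"

// Sketch: construct a CompileSpec using the reorganized types.
trtorch::core::CompileSpec MakeSpec() {
  // ir::InputRange now lives in core/ir rather than core/conversion.
  std::vector<trtorch::core::ir::InputRange> input_ranges;
  input_ranges.push_back(trtorch::core::ir::InputRange({1, 3, 224, 224}));

  trtorch::core::CompileSpec cfg(input_ranges);

  // The fallback knobs moved off engine_settings.torch_fallback and onto
  // the top-level partition_info.
  cfg.partition_info.enabled = true;
  cfg.partition_info.min_block_size = 3;  // illustrative value
  cfg.partition_info.forced_fallback_operators.push_back("aten::relu");  // illustrative op
  return cfg;
}
```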

core/conversion/BUILD (2 additions, 1 deletion)

@@ -23,7 +23,8 @@ cc_library(
         "//core/conversion/conversionctx",
         "//core/conversion/converters",
         "//core/conversion/evaluators",
-        "//core/util:prelude"
+        "//core/util:prelude",
+        "//core/ir",
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
         "//conditions:default": ["@libtorch//:libtorch"],

core/conversion/InterfaceTypes.cpp (0 additions, 49 deletions)

@@ -23,55 +23,6 @@ GraphParams get_named_params(c10::ArrayRef<torch::jit::Value*> inputs, std::vect
   return std::move(named_params);
 }
 
-InputRange::InputRange(std::vector<int64_t> d) {
-  if (d.size() > 5) {
-    LOG_WARNING("Verify that this dim size is accepted");
-  }
-
-  opt = util::toDims(d);
-  min = util::toDims(d);
-  max = util::toDims(d);
-  input_shape = util::toDims(d);
-  input_is_dynamic = false;
-}
-
-InputRange::InputRange(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape) {
-  if (min_shape.size() > 5 || opt_shape.size() > 5 || max_shape.size() > 5) {
-    LOG_WARNING("Verify that this dim size is accepted");
-  }
-
-  std::set<size_t> sizes;
-  sizes.insert(min_shape.size());
-  sizes.insert(opt_shape.size());
-  sizes.insert(max_shape.size());
-
-  if (sizes.size() != 1) {
-    LOG_ERROR(
-        "Expected all input sizes have the same dimensions, but found dimensions: min("
-        << min_shape.size() << "), opt(" << opt_shape.size() << "), max(" << max_shape.size() << ")");
-  }
-
-  min = util::toDims(min_shape);
-  opt = util::toDims(opt_shape);
-  max = util::toDims(max_shape);
-
-  std::vector<int64_t> dyn_shape;
-  for (size_t i = 0; i < opt_shape.size(); i++) {
-    std::set<uint64_t> dim;
-    dim.insert(min_shape[i]);
-    dim.insert(opt_shape[i]);
-    dim.insert(max_shape[i]);
-    if (dim.size() != 1) {
-      dyn_shape.push_back(-1);
-      input_is_dynamic = true;
-    } else {
-      dyn_shape.push_back(opt_shape[i]);
-    }
-  }
-
-  input_shape = util::toDims(dyn_shape);
-}
-
 } // namespace conversion
 } // namespace core
 } // namespace trtorch

core/conversion/conversion.cpp (1 addition, 1 deletion)

@@ -118,7 +118,7 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
       << "please report this error to https://www.github.com/NVIDIA/TRTorch/issues");
 }
 
-void AddInputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> inputs, std::vector<InputRange>& input_dims) {
+void AddInputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> inputs, std::vector<ir::InputRange>& input_dims) {
   std::vector<const torch::jit::Value*> input_tensors;
   for (auto in : inputs) {
     // Disregarding inputs that are not tensors

core/conversion/conversion.h (3 additions, 13 deletions)

@@ -4,27 +4,17 @@
 
 #include "NvInfer.h"
 #include "core/conversion/conversionctx/ConversionCtx.h"
+#include "core/ir/ir.h"
 #include "torch/csrc/jit/ir/ir.h"
 
 namespace trtorch {
 namespace core {
 namespace conversion {
 
-struct InputRange {
-  nvinfer1::Dims min;
-  nvinfer1::Dims max;
-  nvinfer1::Dims opt;
-  nvinfer1::Dims input_shape;
-  bool input_is_dynamic = false;
-  // Should we restrict to unsigned?
-  InputRange(std::vector<int64_t> d);
-  InputRange(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape);
-};
-
 struct ConversionInfo {
-  std::vector<InputRange> input_ranges;
+  std::vector<ir::InputRange> input_ranges;
   BuilderSettings engine_settings;
-  ConversionInfo(std::vector<InputRange> input_ranges)
+  ConversionInfo(std::vector<ir::InputRange> input_ranges)
       : input_ranges(std::move(input_ranges)), engine_settings(BuilderSettings()) {}
 };
 
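ConversionInfo keeps its shape but now holds the shared ir::InputRange type. A minimal sketch, assuming the header above (the static shape is illustrative):

```cpp
#include "core/conversion/conversion.h"

// Construct ConversionInfo from the relocated ir::InputRange type;
// engine_settings is default-constructed, and the fallback settings
// no longer live in BuilderSettings.
std::vector<trtorch::core::ir::InputRange> ranges = {
    trtorch::core::ir::InputRange({1, 3, 224, 224})};  // illustrative shape
trtorch::core::conversion::ConversionInfo info(ranges);
```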

core/conversion/conversionctx/ConversionCtx.cpp (0 additions, 11 deletions)

@@ -36,17 +36,6 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
   }
   os << "\n Engine Capability: " << s.capability \
      << "\n Calibrator Created: " << (s.calibrator != nullptr);
-
-  os << "\n Torch Fallback: " << s.torch_fallback.enabled;
-  if (s.torch_fallback.enabled) {
-    os << "\n Fallback Min Block Size: " << s.torch_fallback.min_block_size;
-    if (!s.torch_fallback.forced_fallback_operators.empty()) {
-      os << "\n Forced Fallback Operators:";
-      for (auto it = s.torch_fallback.forced_fallback_operators.begin(); it != s.torch_fallback.forced_fallback_operators.end(); ++it) {
-        os << " " << *it;
-      }
-    }
-  }
   return os;
 }
 // clang-format on

core/conversion/conversionctx/ConversionCtx.h (0 additions, 7 deletions)

@@ -22,12 +22,6 @@ struct Device {
   Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
 };
 
-struct TorchFallback {
-  bool enabled = false;
-  uint64_t min_block_size = 1;
-  std::vector<std::string> forced_fallback_operators;
-};
-
 struct BuilderSettings {
   nvinfer1::DataType op_precision = nvinfer1::DataType::kFLOAT;
   bool disable_tf32 = false;
@@ -36,7 +30,6 @@ struct BuilderSettings {
   bool strict_types = false;
   bool truncate_long_and_double = false;
   Device device;
-  TorchFallback torch_fallback;
   nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kDEFAULT;
   nvinfer1::IInt8Calibrator* calibrator = nullptr;
   uint64_t num_min_timing_iters = 2;

core/ir/BUILD (new file, 35 additions)

@@ -0,0 +1,35 @@
+package(default_visibility = ["//visibility:public"])
+
+config_setting(
+    name = "use_pre_cxx11_abi",
+    values = {
+        "define": "abi=pre_cxx11_abi",
+    }
+)
+
+cc_library(
+    name = "ir",
+    hdrs = [
+        "ir.h"
+    ],
+    srcs = [
+        "InputRange.cpp",
+    ],
+    deps = [
+        "@tensorrt//:nvinfer",
+        "//core/util:prelude",
+    ] + select({
+        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
+        "//conditions:default": ["@libtorch//:libtorch"],
+    }),
+)
+
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
+pkg_tar(
+    name = "include",
+    package_dir = "core/ir/",
+    srcs = [
+        "ir.h",
+    ],
+)

core/ir/InputRange.cpp (new file, 59 additions)

@@ -0,0 +1,59 @@
+#include "core/ir/ir.h"
+#include "core/util/prelude.h"
+
+namespace trtorch {
+namespace core {
+namespace ir {
+
+InputRange::InputRange(std::vector<int64_t> d) {
+  if (d.size() > 5) {
+    LOG_WARNING("Verify that this dim size is accepted");
+  }
+
+  opt = util::toDims(d);
+  min = util::toDims(d);
+  max = util::toDims(d);
+  input_shape = util::toDims(d);
+  input_is_dynamic = false;
+}
+
+InputRange::InputRange(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape) {
+  if (min_shape.size() > 5 || opt_shape.size() > 5 || max_shape.size() > 5) {
+    LOG_WARNING("Verify that this dim size is accepted");
+  }
+
+  std::set<size_t> sizes;
+  sizes.insert(min_shape.size());
+  sizes.insert(opt_shape.size());
+  sizes.insert(max_shape.size());
+
+  if (sizes.size() != 1) {
+    LOG_ERROR(
+        "Expected all input sizes have the same dimensions, but found dimensions: min("
+        << min_shape.size() << "), opt(" << opt_shape.size() << "), max(" << max_shape.size() << ")");
+  }
+
+  min = util::toDims(min_shape);
+  opt = util::toDims(opt_shape);
+  max = util::toDims(max_shape);
+
+  std::vector<int64_t> dyn_shape;
+  for (size_t i = 0; i < opt_shape.size(); i++) {
+    std::set<uint64_t> dim;
+    dim.insert(min_shape[i]);
+    dim.insert(opt_shape[i]);
+    dim.insert(max_shape[i]);
+    if (dim.size() != 1) {
+      dyn_shape.push_back(-1);
+      input_is_dynamic = true;
+    } else {
+      dyn_shape.push_back(opt_shape[i]);
+    }
+  }
+
+  input_shape = util::toDims(dyn_shape);
+}
+
+} // namespace ir
+} // namespace core
+} // namespace trtorch
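The constructors move here verbatim from InterfaceTypes.cpp; the three-argument form marks every dimension that disagrees across min/opt/max as dynamic (-1). A short sketch of that behavior with illustrative shapes:

```cpp
#include "core/ir/ir.h"

// Only the batch dimension varies, so per the loop above it becomes -1
// while the agreeing dimensions keep their opt values.
trtorch::core::ir::InputRange range(
    /*min_shape=*/{1, 3, 224, 224},
    /*opt_shape=*/{8, 3, 224, 224},
    /*max_shape=*/{32, 3, 224, 224});
// Expected: range.input_shape == (-1, 3, 224, 224) and
//           range.input_is_dynamic == true
```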

core/ir/ir.h (new file, 23 additions)

@@ -0,0 +1,23 @@
+#pragma once
+
+#include <vector>
+#include "NvInfer.h"
+
+namespace trtorch {
+namespace core {
+namespace ir {
+
+struct InputRange {
+  nvinfer1::Dims min;
+  nvinfer1::Dims max;
+  nvinfer1::Dims opt;
+  nvinfer1::Dims input_shape;
+  bool input_is_dynamic = false;
+  // Should we restrict to unsigned?
+  InputRange(std::vector<int64_t> d);
+  InputRange(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape);
+};
+
+} // namespace ir
+} // namespace core
+} // namespace trtorch

core/partitioning/BUILD (9 additions, 2 deletions)

@@ -12,6 +12,7 @@ cc_library(
     hdrs = [
         "SegmentedBlock.h",
         "shape_analysis.h",
+        "PartitionInfo.h",
        "partitioning.h",
     ],
     srcs = [
@@ -20,8 +21,9 @@ cc_library(
         "partitioning.cpp",
     ],
     deps = [
-        "//core/conversion",
        "//core/util:prelude",
+        "//core/ir",
+        "//core/conversion",
        "//core/lowering"
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
@@ -35,6 +37,11 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar")
 pkg_tar(
     name = "include",
     package_dir = "core/partitioning/",
-    srcs = ["partitioning.h"],
+    srcs = [
+        "SegmentedBlock.h",
+        "shape_analysis.h",
+        "PartitionInfo.h",
+        "partitioning.h",
+    ],
 )
 

core/partitioning/PartitionInfo.h (new file, 19 additions)

@@ -0,0 +1,19 @@
+#pragma once
+
+#include <cstdint>
+#include <vector>
+#include <string>
+
+namespace trtorch {
+namespace core {
+namespace partitioning {
+
+struct PartitionInfo {
+  bool enabled = false;
+  uint64_t min_block_size = 1;
+  std::vector<std::string> forced_fallback_operators;
+};
+
+} // namespace partitioning
+} // namespace core
+} // namespace trtorch
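PartitionInfo carries the same three fields as the TorchFallback struct it replaces in ConversionCtx.h. A hedged sketch of populating it and handing it to the partitioner, mirroring the Partition call in core/compiler.cpp above (the MakePartitionInfo helper, values, and operator name are illustrative):

```cpp
#include "core/partitioning/PartitionInfo.h"

// Sketch: populate the relocated fallback settings.
trtorch::core::partitioning::PartitionInfo MakePartitionInfo() {
  trtorch::core::partitioning::PartitionInfo info;
  info.enabled = true;      // was engine_settings.torch_fallback.enabled
  info.min_block_size = 3;  // illustrative value
  info.forced_fallback_operators = {"aten::upsample_nearest2d"};  // illustrative op
  return info;
}

// Then, as in CompileGraphWithFallback above (g and input_ranges elided):
//   auto segmented_blocks = partitioning::Partition(g, input_ranges, MakePartitionInfo());
```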
