@@ -18,6 +18,9 @@ struct InputRange {
18
18
std::vector<int64_t > max;
19
19
20
20
core::conversion::InputRange toInternalInputRange () {
21
+ for (auto o : opt) {
22
+ std::cout << o << std::endl;
23
+ }
21
24
return core::conversion::InputRange (min, opt, max);
22
25
}
23
26
};
@@ -76,7 +79,11 @@ nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) {
76
79
struct ExtraInfo {
77
80
78
81
core::ExtraInfo toInternalExtraInfo () {
79
- auto info = core::ExtraInfo (input_ranges);
82
+ std::cout << " HELLO" << input_ranges.size () << std::endl;
83
+ for (auto i : input_ranges) {
84
+ internal_input_ranges.push_back (i.toInternalInputRange ());
85
+ }
86
+ auto info = core::ExtraInfo (internal_input_ranges);
80
87
info.convert_info .engine_settings .op_precision = toTRTDataType (op_precision);
81
88
info.convert_info .engine_settings .refit = refit;
82
89
info.convert_info .engine_settings .debug = debug;
@@ -91,7 +98,8 @@ struct ExtraInfo {
91
98
return info;
92
99
}
93
100
94
- std::vector<core::conversion::InputRange> input_ranges;
101
+ std::vector<InputRange> input_ranges;
102
+ std::vector<core::conversion::InputRange> internal_input_ranges;
95
103
DataType op_precision = DataType::kFloat ;
96
104
bool refit = false ;
97
105
bool debug = false ;
@@ -112,10 +120,10 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, ExtraInfo& info)
112
120
return trt_mod;
113
121
}
114
122
115
- std::string ConvertGraphToTRTEngine (const torch::jit::Module& mod, const std::string& method_name, ExtraInfo& info) {
123
+ py::bytes ConvertGraphToTRTEngine (const torch::jit::Module& mod, const std::string& method_name, ExtraInfo& info) {
116
124
py::gil_scoped_acquire gil;
117
125
auto trt_engine = core::ConvertGraphToTRTEngine (mod, method_name, info.toInternalExtraInfo ());
118
- return trt_engine;
126
+ return py::bytes ( trt_engine) ;
119
127
}
120
128
121
129
bool CheckMethodOperatorSupport (const torch::jit::Module& module , const std::string& method_name) {
@@ -136,11 +144,7 @@ PYBIND11_MODULE(_C, m) {
136
144
.def (py::init<>())
137
145
.def_readwrite (" min" , &InputRange::min)
138
146
.def_readwrite (" opt" , &InputRange::opt)
139
- .def_readwrite (" max" , &InputRange::max)
140
- .def (" _to_internal_input_range" , &InputRange::toInternalInputRange);
141
-
142
- // py::class_<core::conversion::InputRange>(m, "_InternalInputRange")
143
- // .def(py::init<>());
147
+ .def_readwrite (" max" , &InputRange::max);
144
148
145
149
py::enum_<DataType>(m, " dtype" )
146
150
.value (" float" , DataType::kFloat )
0 commit comments