4
4
namespace trtorch {
5
5
namespace pyapi {
6
6
7
- std::string to_str (InputRange& value ) {
7
+ std::string InputRange:: to_str () {
8
8
auto vec_to_str = [](std::vector<int64_t > shape) -> std::string {
9
9
std::stringstream ss;
10
10
ss << ' [' ;
@@ -17,9 +17,9 @@ std::string to_str(InputRange& value) {
17
17
18
18
std::stringstream ss;
19
19
ss << " {" << std::endl;
20
- ss << " min: " << vec_to_str (value. min ) << ' ,' << std::endl;
21
- ss << " opt: " << vec_to_str (value. opt ) << ' ,' << std::endl;
22
- ss << " max: " << vec_to_str (value. max ) << ' ,' << std::endl;
20
+ ss << " min: " << vec_to_str (min) << ' ,' << std::endl;
21
+ ss << " opt: " << vec_to_str (opt) << ' ,' << std::endl;
22
+ ss << " max: " << vec_to_str (max) << ' ,' << std::endl;
23
23
ss << " }" << std::endl;
24
24
return ss.str ();
25
25
}
@@ -68,6 +68,18 @@ nvinfer1::DeviceType toTRTDeviceType(DeviceType value) {
68
68
}
69
69
}
70
70
71
+ std::string Device::to_str () {
72
+ std::stringstream ss;
73
+ std::string fallback = allow_gpu_fallback ? " True" : " False" ;
74
+ ss << " {" << std::endl;
75
+ ss << " \" device_type\" : " << pyapi::to_str (device_type) << std::endl;
76
+ ss << " \" allow_gpu_fallback\" : " << fallback << std::endl;
77
+ ss << " \" gpu_id\" : " << gpu_id << std::endl;
78
+ ss << " \" dla_core\" : " << dla_core << std::endl;
79
+ ss << " }" << std::endl;
80
+ return ss.str ();
81
+ }
82
+
71
83
std::string to_str (EngineCapability value) {
72
84
switch (value) {
73
85
case EngineCapability::kSAFE_GPU :
@@ -92,6 +104,21 @@ nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) {
92
104
}
93
105
}
94
106
107
+ std::string TorchFallback::to_str () {
108
+ std::stringstream ss;
109
+ std::string e = enabled ? " True" : " False" ;
110
+ ss << " {" << std::endl;
111
+ ss << " \" enabled\" : " << e << std::endl;
112
+ ss << " \" min_block_size\" : " << min_block_size << std::endl;
113
+ ss << " \" forced_fallback_operators\" : [" << std::endl;
114
+ for (auto i : forced_fallback_operators) {
115
+ ss << " " << i << ' ,' << std::endl;
116
+ }
117
+ ss << " ]" << std::endl;
118
+ ss << " }" << std::endl;
119
+ return ss.str ();
120
+ }
121
+
95
122
core::CompileSpec CompileSpec::toInternalCompileSpec () {
96
123
std::vector<core::ir::InputRange> internal_input_ranges;
97
124
for (auto i : input_ranges) {
@@ -128,36 +155,25 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
128
155
std::string CompileSpec::stringify () {
129
156
std::stringstream ss;
130
157
ss << " TensorRT Compile Spec: {" << std::endl;
131
- ss << " \" Input Shapes\" : [" << std::endl;
158
+ ss << " \" Input Shapes\" : [" << std::endl;
132
159
for (auto i : input_ranges) {
133
- ss << to_str (i );
160
+ ss << i. to_str ();
134
161
}
135
162
std::string enabled = torch_fallback.enabled ? " True" : " False" ;
136
- ss << " ]" << std::endl;
137
- ss << " \" Op Precision\" : " << to_str (op_precision) << std::endl;
138
- ss << " \" TF32 Disabled\" : " << disable_tf32 << std::endl;
139
- ss << " \" Refit\" : " << refit << std::endl;
140
- ss << " \" Debug\" : " << debug << std::endl;
141
- ss << " \" Strict Types\" : " << strict_types << std::endl;
142
- ss << " \" Device Type: " << to_str (device.device_type ) << std::endl;
143
- ss << " \" GPU ID: " << device.gpu_id << std::endl;
144
- ss << " \" DLA Core: " << device.dla_core << std::endl;
145
- ss << " \" Allow GPU Fallback\" : " << device.allow_gpu_fallback << std::endl;
146
- ss << " \" Engine Capability\" : " << to_str (capability) << std::endl;
147
- ss << " \" Num Min Timing Iters\" : " << num_min_timing_iters << std::endl;
148
- ss << " \" Num Avg Timing Iters\" : " << num_avg_timing_iters << std::endl;
149
- ss << " \" Workspace Size\" : " << workspace_size << std::endl;
150
- ss << " \" Max Batch Size\" : " << max_batch_size << std::endl;
151
- ss << " \" Truncate long and double\" : " << truncate_long_and_double << std::endl;
152
- ss << " \" Torch Fallback: {" << std::endl;
153
- ss << " \" enabled\" : " << enabled << std::endl;
154
- ss << " \" min_block_size\" : " << torch_fallback.min_block_size << std::endl;
155
- ss << " \" forced_fallback_operators\" : [" << std::endl;
156
- for (auto i : torch_fallback.forced_fallback_operators ) {
157
- ss << " " << i << ' ,' << std::endl;
158
- }
159
- ss << " ]" << std::endl;
160
- ss << " }" << std::endl;
163
+ ss << " ]" << std::endl;
164
+ ss << " \" Op Precision\" : " << to_str (op_precision) << std::endl;
165
+ ss << " \" TF32 Disabled\" : " << disable_tf32 << std::endl;
166
+ ss << " \" Refit\" : " << refit << std::endl;
167
+ ss << " \" Debug\" : " << debug << std::endl;
168
+ ss << " \" Strict Types\" : " << strict_types << std::endl;
169
+ ss << " \" Device\" : " << device.to_str () << std::endl;
170
+ ss << " \" Engine Capability\" : " << to_str (capability) << std::endl;
171
+ ss << " \" Num Min Timing Iters\" : " << num_min_timing_iters << std::endl;
172
+ ss << " \" Num Avg Timing Iters\" : " << num_avg_timing_iters << std::endl;
173
+ ss << " \" Workspace Size\" : " << workspace_size << std::endl;
174
+ ss << " \" Max Batch Size\" : " << max_batch_size << std::endl;
175
+ ss << " \" Truncate long and double\" : " << truncate_long_and_double << std::endl;
176
+ ss << " \" Torch Fallback\" : " << torch_fallback.to_str ();
161
177
ss << " }" ;
162
178
return ss.str ();
163
179
}
0 commit comments