@@ -32,6 +32,7 @@ int main(int argc, const char* argv[]) {
     // Create the calibration dataset
     const std::string data_dir = std::string(argv[2]);
     auto calibration_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest)
+                                   .use_subset(320)
                                    .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465},
                                                                              {0.2023, 0.1994, 0.2010}))
                                    .map(torch::data::transforms::Stack<>());
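Note: the calibration_dataloader handed to the calibrator in the next hunk is not part of this change. A minimal sketch of how it is typically built from the dataset above, assuming the standard torch::data API; the batch size and worker count shown here are illustrative, not taken from this commit:

    // Wrap the (subsetted, normalized, stacked) CIFAR10 dataset in a data loader
    // so the INT8 calibrator can iterate over calibration batches.
    auto calibration_dataloader = torch::data::make_data_loader(
        std::move(calibration_dataset),
        torch::data::DataLoaderOptions().batch_size(32).workers(2));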
@@ -41,19 +42,19 @@ int main(int argc, const char* argv[]) {

     std::string calibration_cache_file = "/tmp/vgg16_TRT_ptq_calibration.cache";

-    auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true);
+    auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, false);
     // auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file);


     std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
     // Configure settings for compilation
     auto extra_info = trtorch::ExtraInfo({input_shape});
     // Set operating precision to INT8
-    extra_info.op_precision = torch::kChar;
+    extra_info.op_precision = torch::kI8;
     // Use the TensorRT Entropy Calibrator
     extra_info.ptq_calibrator = calibrator;
-    // Increase the default workspace size;
-    extra_info.workspace_size = 1 << 30;
+    // Set max batch size for the engine
+    extra_info.max_batch_size = 32;


     mod.eval();
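Note: the compile step itself lies outside this hunk. A minimal sketch of how the configured extra_info and calibrator are typically consumed, assuming the TRTorch C++ API of this era exposes trtorch::CompileGraph; the call shown here is an assumption, not part of this diff:

    // Compile the TorchScript module into a TensorRT-backed module using the
    // INT8 precision, calibrator, and max batch size configured above.
    auto trt_mod = trtorch::CompileGraph(mod, extra_info);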
@@ -92,6 +93,14 @@ int main(int argc, const char* argv[]) {

         auto outputs = trt_mod.forward({images});
         auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
+        predictions = predictions.reshape(predictions.sizes()[0]);
+
+        if (predictions.sizes()[0] != targets.sizes()[0]) {
+            // To handle smaller batches until Optimization profiles work
+            predictions = predictions.slice(0, 0, targets.sizes()[0]);
+        }
+
+        std::cout << predictions << targets << std::endl;

         total += targets.sizes()[0];
         correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
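Note: the hunk ends inside the evaluation loop. A sketch of how the accumulated counters are typically reported once the loop finishes; the message text is illustrative, not taken from this commit:

    // After the loop: report top-1 accuracy from the counters updated above.
    std::cout << "Accuracy: " << 100.0 * (correct / total) << "%" << std::endl;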