Skip to content

Commit 5f36f47

Browse files
committed
feat(//cpp/ptq): Add a feature to the dataset to use less than the full
test set for calibration Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 4a8dc6e commit 5f36f47

File tree

3 files changed

+34
-8
lines changed

3 files changed

+34
-8
lines changed

cpp/ptq/datasets/cifar10.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "cpp/ptq/datasets/cifar10.h"
22

3+
#include "torch/torch.h"
34
#include "torch/data/example.h"
45
#include "torch/types.h"
56

@@ -63,15 +64,19 @@ std::pair<torch::Tensor, torch::Tensor> read_batch(const std::string& path) {
6364
}
6465

6566
std::pair<torch::Tensor, torch::Tensor> read_train_data(const std::string& root) {
66-
torch::Tensor images, targets;
67+
std::vector<torch::Tensor> images, targets;
6768
for(uint32_t i = 1; i <= 5; i++) {
6869
std::stringstream ss;
6970
ss << root << '/' << kTrainFilenamePrefix << i << ".bin";
7071
auto batch = read_batch(ss.str());
71-
images = torch::stack({images, batch.first});
72-
targets = torch::stack({targets, batch.second});
72+
images.push_back(batch.first);
73+
targets.push_back(batch.second);
7374
}
74-
return std::make_pair(images, targets);
75+
76+
torch::Tensor image_tensor = std::accumulate(++images.begin(), images.end(), *images.begin(), [&](torch::Tensor a, torch::Tensor b) {return torch::cat({a, b}, 0);});
77+
torch::Tensor target_tensor = std::accumulate(++targets.begin(), targets.end(), *targets.begin(), [&](torch::Tensor a, torch::Tensor b) {return torch::cat({a, b}, 0);});
78+
79+
return std::make_pair(image_tensor, target_tensor);
7580
}
7681

7782
std::pair<torch::Tensor, torch::Tensor> read_test_data(const std::string& root) {
@@ -93,6 +98,7 @@ CIFAR10::CIFAR10(const std::string& root, Mode mode)
9398

9499
images_ = std::move(data.first);
95100
targets_ = std::move(data.second);
101+
assert(images_.sizes()[0] == images_.sizes()[0]);
96102
}
97103

98104
torch::data::Example<> CIFAR10::get(size_t index) {
@@ -115,5 +121,12 @@ const torch::Tensor& CIFAR10::targets() const {
115121
return targets_;
116122
}
117123

124+
CIFAR10&& CIFAR10::use_subset(int64_t new_size) {
125+
assert(new_size <= images_.sizes()[0]);
126+
images_ = images_.slice(0, 0, new_size);
127+
targets_ = targets_.slice(0, 0, new_size);
128+
return std::move(*this);
129+
}
130+
118131
} // namespace datasets
119132

cpp/ptq/datasets/cifar10.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ class CIFAR10 : public torch::data::datasets::Dataset<CIFAR10> {
3434
// Returns all targets stacked into a single tensor
3535
const torch::Tensor& targets() const;
3636

37+
// Trims the dataset to the first n pairs
38+
CIFAR10&& use_subset(int64_t new_size);
39+
40+
3741
private:
3842
Mode mode_;
3943
torch::Tensor images_, targets_;

cpp/ptq/main.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ int main(int argc, const char* argv[]) {
3232
// Create the calibration dataset
3333
const std::string data_dir = std::string(argv[2]);
3434
auto calibration_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest)
35+
.use_subset(320)
3536
.map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465},
3637
{0.2023, 0.1994, 0.2010}))
3738
.map(torch::data::transforms::Stack<>());
@@ -41,19 +42,19 @@ int main(int argc, const char* argv[]) {
4142

4243
std::string calibration_cache_file = "/tmp/vgg16_TRT_ptq_calibration.cache";
4344

44-
auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true);
45+
auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, false);
4546
//auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file);
4647

4748

4849
std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
4950
// Configure settings for compilation
5051
auto extra_info = trtorch::ExtraInfo({input_shape});
5152
// Set operating precision to INT8
52-
extra_info.op_precision = torch::kChar;
53+
extra_info.op_precision = torch::kFI8;
5354
// Use the TensorRT Entropy Calibrator
5455
extra_info.ptq_calibrator = calibrator;
55-
// Increase the default workspace size;
56-
extra_info.workspace_size = 1 << 30;
56+
// Set max batch size for the engine
57+
extra_info.max_batch_size = 32;
5758

5859
mod.eval();
5960

@@ -92,6 +93,14 @@ int main(int argc, const char* argv[]) {
9293

9394
auto outputs = trt_mod.forward({images});
9495
auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
96+
predictions = predictions.reshape(predictions.sizes()[0]);
97+
98+
if (predictions.sizes()[0] != targets.sizes()[0]) {
99+
// To handle smaller batches util Optimization profiles work
100+
predictions = predictions.slice(0, 0, targets.sizes()[0]);
101+
}
102+
103+
std:: cout << predictions << targets << std::endl;
95104

96105
total += targets.sizes()[0];
97106
correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();

0 commit comments

Comments
 (0)