Skip to content

Commit 2f86f84

Browse files
committed
feat(//cpp/api): Remove the extra includes in the API header
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent f022dfe commit 2f86f84

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

cpp/api/include/trtorch/ptq.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@ class IInt8Calibrator;
1111
class IInt8EntropyCalibrator2;
1212
}
1313

14+
namespace torch {
15+
namespace data {
16+
template<typename Example>
17+
class Iterator;
18+
}
19+
}
20+
1421
namespace trtorch {
1522
namespace ptq {
1623

cpp/api/include/trtorch/trtorch.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212
#include <vector>
1313
#include <memory>
1414

15-
#include "torch/torch.h"
16-
#include "NvInfer.h"
17-
1815
// Just include the .h?
1916
namespace torch {
2017
namespace jit {

cpp/ptq/main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ int main(int argc, const char* argv[]) {
4141

4242
std::string calibration_cache_file = "/tmp/vgg16_TRT_ptq_calibration.cache";
4343

44-
//auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true);
45-
auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file);
44+
auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true);
45+
//auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file);
4646

4747

4848
std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};

0 commit comments

Comments
 (0)