Commit d2f8a59

refactor(//cpp/api)!: Refactor ptq to use includes but keep it separate from the core header

BREAKING CHANGE: To use ptq you now need to include trtorch/ptq.h in addition to trtorch/trtorch.h; similarly, for logging commands you need to include trtorch/logging.h.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
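In practice, the breaking change means consumer code pulls in each header explicitly; a minimal sketch of the new include pattern (the header names come from the commit message above, the comments are editorial):

#include "trtorch/trtorch.h" // core compilation API
#include "trtorch/ptq.h"     // now required for the trtorch::ptq calibrator factories
#include "trtorch/logging.h" // now required for logging commands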
1 parent 842a567 commit d2f8a59

File tree: 5 files changed (+71, -53 lines)


cpp/api/include/trtorch/logging.h

Lines changed: 7 additions & 0 deletions
@@ -1,3 +1,10 @@
+/*
+ * Copyright (c) NVIDIA Corporation.
+ * All rights reserved.
+ *
+ * This library is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
 #pragma once
 
 #include <string>

cpp/api/include/trtorch/macros.h

Lines changed: 7 additions & 0 deletions
@@ -1,3 +1,10 @@
+/*
+ * Copyright (c) NVIDIA Corporation.
+ * All rights reserved.
+ *
+ * This library is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
 #pragma once
 
 #if defined(__GNUC__)

cpp/api/include/trtorch/ptq.h

Lines changed: 55 additions & 4 deletions
@@ -1,23 +1,30 @@
+/*
+ * Copyright (c) NVIDIA Corporation.
+ * All rights reserved.
+ *
+ * This library is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
 #pragma once
 
 #include <string>
 #include <vector>
 #include <memory>
 #include <iostream>
+#include <fstream>
+#include <iterator>
 #include <sstream>
 
+#include "torch/torch.h"
 #include "trtorch/logging.h"
+#include "NvInfer.h"
 
 #ifndef DOXYGEN_SHOULD_SKIP_THIS
 namespace nvinfer1 {
 class IInt8Calibrator;
 class IInt8EntropyCalibrator2;
 }
 
-namespace torch {
-class Tensor;
-}
-
 namespace trtorch {
 namespace ptq {
 bool get_batch_impl(void* bindings[], const char* names[], int nbBindings, torch::Tensor& data);

@@ -269,5 +276,49 @@ class Int8CacheCalibrator : Algorithm {
   std::vector<char> cache_;
 };
 
+/**
+ * @brief A factory to build a post training quantization calibrator from a torch dataloader
+ *
+ * Creates a calibrator to use for post training quantization. By default the returned calibrator uses the TensorRT Entropy v2
+ * algorithm to perform calibration. This is recommended for feed-forward networks. You can override the algorithm selection
+ * (such as to use the MinMax Calibrator recommended for NLP tasks) by calling make_int8_calibrator with the calibrator class
+ * as a template parameter.
+ *
+ * e.g. ``trtorch::ptq::make_int8_calibrator<nvinfer1::IInt8MinMaxCalibrator>(std::move(calibration_dataloader), calibration_cache_file, use_cache);``
+ * @tparam Algorithm: class nvinfer1::IInt8Calibrator (Default: nvinfer1::IInt8EntropyCalibrator2) - Algorithm to use
+ * @tparam DataLoader: std::unique_ptr<torch::data::DataLoader> - DataLoader type
+ * @param dataloader: std::unique_ptr<torch::data::DataLoader> - DataLoader containing data
+ * @param cache_file_path: const std::string& - Path to read/write calibration cache
+ * @param use_cache: bool - use calibration cache
+ * @return Int8Calibrator<Algorithm, DataLoader>
+ */
+template<typename Algorithm = nvinfer1::IInt8EntropyCalibrator2, typename DataLoader>
+TRTORCH_API inline Int8Calibrator<Algorithm, DataLoader> make_int8_calibrator(DataLoader dataloader, const std::string& cache_file_path, bool use_cache) {
+  return Int8Calibrator<Algorithm, DataLoader>(std::move(dataloader), cache_file_path, use_cache);
+}
+
+/**
+ * @brief A factory to build a post training quantization calibrator from a torch dataloader that only uses the calibration cache
+ *
+ * Creates a calibrator to use for post training quantization which reads from a previously created calibration cache. This lets
+ * you have a calibration-cache-generating program that requires a dataloader and a dataset, then save the cache for use later
+ * in a different program that needs to calibrate from scratch without the dataset dependency. However, the network should also
+ * be recalibrated if its structure changes, or the input data set changes, and it is the responsibility of the application to ensure this.
+ *
+ * By default the returned calibrator uses the TensorRT Entropy v2 algorithm to perform calibration. This is recommended for feed-forward networks.
+ * You can override the algorithm selection (such as to use the MinMax Calibrator recommended for NLP tasks) by calling make_int8_cache_calibrator with
+ * the calibrator class as a template parameter.
+ *
+ * e.g. ``trtorch::ptq::make_int8_cache_calibrator<nvinfer1::IInt8MinMaxCalibrator>(calibration_cache_file);``
+ * @tparam Algorithm: class nvinfer1::IInt8Calibrator (Default: nvinfer1::IInt8EntropyCalibrator2) - Algorithm to use
+ * @param cache_file_path: const std::string& - Path to read/write calibration cache
+ * @return Int8CacheCalibrator<Algorithm>
+ */
+template<typename Algorithm = nvinfer1::IInt8EntropyCalibrator2>
+TRTORCH_API inline Int8CacheCalibrator<Algorithm> make_int8_cache_calibrator(const std::string& cache_file_path) {
+  return Int8CacheCalibrator<Algorithm>(cache_file_path);
+}
+
 } // namespace ptq
 } // namespace trtorch
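The Doxygen blocks above give the override form with an explicit template argument; for the default Entropy v2 algorithm, a hedged usage sketch (variable names follow the e.g. lines above; the surrounding setup of the dataloader and cache path is not part of this diff):

// Dataloader-backed calibrator using the default nvinfer1::IInt8EntropyCalibrator2.
auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, /*use_cache=*/true);
// Cache-only calibrator for deployments without the dataset dependency.
auto cache_calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file);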

cpp/api/include/trtorch/trtorch.h

Lines changed: 1 addition & 49 deletions
@@ -29,13 +29,11 @@ class ArrayRef;
 }
 
 namespace nvinfer1 {
-class IInt8EntropyCalibrator2;
+class IInt8Calibrator;
 }
 #endif //DOXYGEN_SHOULD_SKIP_THIS
 
 #include "trtorch/macros.h"
-#include "trtorch/logging.h"
-#include "trtorch/ptq.h"
 namespace trtorch {
 /**
  * Settings data structure for TRTorch compilation

@@ -406,50 +404,4 @@ TRTORCH_API torch::jit::Module CompileGraph(const torch::jit::Module& module, Ex
  * @return: std::string: Serialized TensorRT engine equivalent to the method graph
  */
 TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::Module& module, std::string method_name, ExtraInfo info);
-
-namespace ptq {
-/**
- * @brief A factory to build a post training quantization calibrator from a torch dataloader
- *
- * Creates a calibrator to use for post training quantization. By default the returned calibrator uses the TensorRT Entropy v2
- * algorithm to perform calibration. This is recommended for feed-forward networks. You can override the algorithm selection
- * (such as to use the MinMax Calibrator recommended for NLP tasks) by calling make_int8_calibrator with the calibrator class
- * as a template parameter.
- *
- * e.g. ``trtorch::ptq::make_int8_calibrator<nvinfer1::IInt8MinMaxCalibrator>(std::move(calibration_dataloader), calibration_cache_file, use_cache);``
- * @tparam Algorithm: class nvinfer1::IInt8Calibrator (Default: nvinfer1::IInt8EntropyCalibrator2) - Algorithm to use
- * @tparam DataLoader: std::unique_ptr<torch::data::DataLoader> - DataLoader type
- * @param dataloader: std::unique_ptr<torch::data::DataLoader> - DataLoader containing data
- * @param cache_file_path: const std::string& - Path to read/write calibration cache
- * @param use_cache: bool - use calibration cache
- * @return Int8Calibrator<Algorithm, DataLoader>
- */
-template<typename Algorithm = nvinfer1::IInt8EntropyCalibrator2, typename DataLoader>
-TRTORCH_API inline Int8Calibrator<Algorithm, DataLoader> make_int8_calibrator(DataLoader dataloader, const std::string& cache_file_path, bool use_cache) {
-  return Int8Calibrator<Algorithm, DataLoader>(std::move(dataloader), cache_file_path, use_cache);
-}
-
-/**
- * @brief A factory to build a post training quantization calibrator from a torch dataloader that only uses the calibration cache
- *
- * Creates a calibrator to use for post training quantization which reads from a previously created calibration cache. This lets
- * you have a calibration-cache-generating program that requires a dataloader and a dataset, then save the cache for use later
- * in a different program that needs to calibrate from scratch without the dataset dependency. However, the network should also
- * be recalibrated if its structure changes, or the input data set changes, and it is the responsibility of the application to ensure this.
- *
- * By default the returned calibrator uses the TensorRT Entropy v2 algorithm to perform calibration. This is recommended for feed-forward networks.
- * You can override the algorithm selection (such as to use the MinMax Calibrator recommended for NLP tasks) by calling make_int8_cache_calibrator with
- * the calibrator class as a template parameter.
- *
- * e.g. ``trtorch::ptq::make_int8_cache_calibrator<nvinfer1::IInt8MinMaxCalibrator>(calibration_cache_file);``
- * @tparam Algorithm: class nvinfer1::IInt8Calibrator (Default: nvinfer1::IInt8EntropyCalibrator2) - Algorithm to use
- * @param cache_file_path: const std::string& - Path to read/write calibration cache
- * @return Int8CacheCalibrator<Algorithm>
- */
-template<typename Algorithm = nvinfer1::IInt8EntropyCalibrator2>
-TRTORCH_API inline Int8CacheCalibrator<Algorithm> make_int8_cache_calibrator(const std::string& cache_file_path) {
-  return Int8CacheCalibrator<Algorithm>(cache_file_path);
-}
-} // namespace ptq
 } // namespace trtorch

cpp/ptq/main.cpp

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 #include "torch/script.h"
 #include "torch/torch.h"
 #include "trtorch/trtorch.h"
+#include "trtorch/ptq.h"
 
 #include "NvInfer.h"
 
