@@ -29,13 +29,11 @@ class ArrayRef;
29
29
}
30
30
31
31
namespace nvinfer1 {
32
- class IInt8EntropyCalibrator2 ;
32
+ class IInt8Calibrator ;
33
33
}
34
34
#endif // DOXYGEN_SHOULD_SKIP_THIS
35
35
36
36
#include " trtorch/macros.h"
37
- #include " trtorch/logging.h"
38
- #include " trtorch/ptq.h"
39
37
namespace trtorch {
40
38
/* *
41
39
* Settings data structure for TRTorch compilation
@@ -406,50 +404,4 @@ TRTORCH_API torch::jit::Module CompileGraph(const torch::jit::Module& module, Ex
406
404
* @return: std::string: Serialized TensorRT engine equivalent to the method graph
407
405
*/
408
406
TRTORCH_API std::string ConvertGraphToTRTEngine (const torch::jit::Module& module , std::string method_name, ExtraInfo info);
409
-
410
- namespace ptq {
411
- /* *
412
- * @brief A factory to build a post training quantization calibrator from a torch dataloader
413
- *
414
- * Creates a calibrator to use for post training quantization. By default the returned calibrator uses TensorRT Entropy v2
415
- * algorithm to perform calibration. This is recommended for feed forward networks. You can override the algorithm selection
416
- * (such as to use the MinMax Calibrator recommended for NLP tasks) by calling make_int8_calibrator with the calibrator class
417
- * as a template parameter.
418
- *
419
- * e.g. ``trtorch::ptq::make_int8_calibrator<nvinfer1::IInt8MinMaxCalibrator>(std::move(calibration_dataloader), calibration_cache_file, use_cache);``
420
- * @tparam Algorithm: class nvinfer1::IInt8Calibrator (Default: nvinfer1::IInt8EntropyCalibrator2) - Algorithm to use
421
- * @tparam DataLoader: std::unique_ptr<torch::data::DataLoader> - DataLoader type
422
- * @param dataloader: std::unique_ptr<torch::data::DataLoader> - DataLoader containing data
423
- * @param cache_file_path: const std::string& - Path to read/write calibration cache
424
- * @param use_cache: bool - use calibration cache
425
- * @return Int8Calibrator<Algorithm, DataLoader>
426
- */
427
-
428
- template <typename Algorithm = nvinfer1::IInt8EntropyCalibrator2, typename DataLoader>
429
- TRTORCH_API inline Int8Calibrator<Algorithm, DataLoader> make_int8_calibrator (DataLoader dataloader, const std::string& cache_file_path, bool use_cache) {
430
- return Int8Calibrator<Algorithm, DataLoader>(std::move (dataloader), cache_file_path, use_cache);
431
- }
432
-
433
- /* *
434
- * @brief A factory to build a post training quantization calibrator from a torch dataloader that only uses the calibration cache
435
- *
436
- * Creates a calibrator to use for post training quantization which reads from a previously created calibration cache, therefore
437
- * you can have a calibration cache generating program that requires a dataloader and a dataset, then save the cache to use later
438
- * in a different program that needs to calibrate from scratch and not have the dataset dependency. However, the network should also
439
- * be recalibrated if its structure changes, or the input data set changes, and it is the responsibility of the application to ensure this.
440
- *
441
- * By default the returned calibrator uses TensorRT Entropy v2 algorithm to perform calibration. This is recommended for feed forward networks.
442
- * You can override the algorithm selection (such as to use the MinMax Calibrator recommended for NLP tasks) by calling make_int8_calibrator with
443
- * the calibrator class as a template parameter.
444
- *
445
- * e.g. trtorch::ptq::make_int8_cache_calibrator<nvinfer1::IInt8MinMaxCalibrator>(calibration_cache_file);
446
- * @tparam Algorithm: class nvinfer1::IInt8Calibrator (Default: nvinfer1::IInt8EntropyCalibrator2) - Algorithm to use
447
- * @param cache_file_path: const std::string& - Path to read/write calibration cache
448
- * @return Int8CacheCalibrator<Algorithm>
449
- */
450
- template <typename Algorithm = nvinfer1::IInt8EntropyCalibrator2>
451
- TRTORCH_API inline Int8CacheCalibrator<Algorithm> make_int8_cache_calibrator (const std::string& cache_file_path) {
452
- return Int8CacheCalibrator<Algorithm>(cache_file_path);
453
- }
454
- } // namespace ptq
455
407
} // namespace trtorch
0 commit comments