
Update PTQ example to fix new compile_spec requirements #1242

Merged 1 commit on Aug 8, 2022.

33 changes: 13 additions & 20 deletions docsrc/tutorials/ptq.rst
@@ -167,19 +167,16 @@ a TensorRT calibrator by providing desired configuration. The following code demonstrates
     algo_type=torch_tensorrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
     device=torch.device('cuda:0'))

-compile_spec = {
-    "inputs": [torch_tensorrt.Input((1, 3, 32, 32))],
-    "enabled_precisions": {torch.float, torch.half, torch.int8},
-    "calibrator": calibrator,
-    "device": {
-        "device_type": torch_tensorrt.DeviceType.GPU,
-        "gpu_id": 0,
-        "dla_core": 0,
-        "allow_gpu_fallback": False,
-        "disable_tf32": False
-    }
-}
-trt_mod = torch_tensorrt.compile(model, compile_spec)
+trt_mod = torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input((1, 3, 32, 32))],
+                                 enabled_precisions={torch.float, torch.half, torch.int8},
+                                 calibrator=calibrator,
+                                 device={
+                                     "device_type": torch_tensorrt.DeviceType.GPU,
+                                     "gpu_id": 0,
+                                     "dla_core": 0,
+                                     "allow_gpu_fallback": False,
+                                     "disable_tf32": False
+                                 })

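Put together with the calibrator construction that the hunk truncates above, the updated call reads roughly as below. This is a sketch, not part of the diff: ``model`` and ``testing_dataloader`` are assumed from earlier in the tutorial, and the ``cache_file``/``use_cache`` arguments follow the tutorial's ``DataLoaderCalibrator`` pattern.

import torch
import torch_tensorrt

# Sketch only: `model` is the scripted/traced module and `testing_dataloader`
# is a torch.utils.data.DataLoader over the calibration set, both set up
# earlier in the tutorial.
calibrator = torch_tensorrt.ptq.DataLoaderCalibrator(
    testing_dataloader,
    cache_file='./calibration.cache',
    use_cache=False,
    algo_type=torch_tensorrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
    device=torch.device('cuda:0'))

trt_mod = torch_tensorrt.compile(model,
    inputs=[torch_tensorrt.Input((1, 3, 32, 32))],
    enabled_precisions={torch.float, torch.half, torch.int8},
    calibrator=calibrator,
    device={
        "device_type": torch_tensorrt.DeviceType.GPU,
        "gpu_id": 0,
        "dla_core": 0,
        "allow_gpu_fallback": False,
        "disable_tf32": False
    })
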
In cases where there is a pre-existing calibration cache file that users want to use, ``CacheCalibrator`` can be used without any dataloaders. The following example demonstrates how
to use ``CacheCalibrator`` in INT8 mode.
@@ -188,13 +185,9 @@ to use ``CacheCalibrator`` in INT8 mode.

 calibrator = torch_tensorrt.ptq.CacheCalibrator("./calibration.cache")

-compile_settings = {
-    "inputs": [torch_tensorrt.Input([1, 3, 32, 32])],
-    "enabled_precisions": {torch.float, torch.half, torch.int8},
-    "calibrator": calibrator,
-}
-
-trt_mod = torch_tensorrt.compile(model, compile_settings)
+trt_mod = torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input([1, 3, 32, 32])],
+                                 enabled_precisions={torch.float, torch.half, torch.int8},
+                                 calibrator=calibrator)

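Whichever calibrator is used, ``torch_tensorrt.compile`` (with the TorchScript frontend) still returns a TorchScript module, so saving and reloading the calibrated result is unchanged. A minimal sketch; the filename is illustrative:

import torch

torch.jit.save(trt_mod, "trt_ptq_module.ts")  # illustrative filename
trt_mod = torch.jit.load("trt_ptq_module.ts")
out = trt_mod(torch.randn(1, 3, 32, 32).cuda())  # input must match the compiled Input shape
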
If you already have an existing calibrator class (implemented directly using the TensorRT API), you can simply set the ``calibrator`` field to an instance of your class, which can be very convenient.
For a demo of how PTQ can be performed on a VGG network using the Torch-TensorRT API, see https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_dataloader_calibrator.py
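
As a sketch of that last point: a hand-written calibrator only needs to subclass one of TensorRT's calibrator base classes and implement the four standard callbacks; the instance then goes straight into the ``calibrator`` argument. The class name, batch handling, and cache path below are illustrative, not part of this PR; a dataloader with ``batch_size=1`` yielding ``(image, label)`` pairs and the ``model``/``testing_dataloader`` names are assumed from the tutorial.

import os
import tensorrt as trt
import torch
import torch_tensorrt

class MyEntropyCalibrator(trt.IInt8EntropyCalibrator2):
    # Illustrative calibrator written directly against the TensorRT API.
    def __init__(self, dataloader, cache_file="./calibration.cache"):
        super().__init__()
        self.cache_file = cache_file
        self.batches = iter(dataloader)
        self.current = None  # keep a reference so the GPU buffer stays alive

    def get_batch_size(self):
        return 1  # must match the dataloader's batch size

    def get_batch(self, names):
        try:
            data, _ = next(self.batches)
            self.current = data.cuda()
            return [self.current.data_ptr()]  # device pointer per input binding
        except StopIteration:
            return None  # tells TensorRT the calibration data is exhausted

    def read_calibration_cache(self):
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)

trt_mod = torch_tensorrt.compile(model,
    inputs=[torch_tensorrt.Input((1, 3, 32, 32))],
    enabled_precisions={torch.float, torch.half, torch.int8},
    calibrator=MyEntropyCalibrator(testing_dataloader))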