Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit fe7fc94

Browse files
peri044zewenli98narendasan
authoredMay 30, 2024
feat: Implement FP8 functionality (#2763)
Co-authored-by: Evan Li <[email protected]> Co-authored-by: Naren Dasan <[email protected]>
1 parent 856f33d commit fe7fc94

30 files changed

+522
-73
lines changed
 
Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
1+
#!/usr/bin/env bash
12
set -eou pipefail
2-
source "${BUILD_ENV_FILE}"
3+
# Source conda so it's available to the script environment
4+
source ${BUILD_ENV_FILE}
5+
export EXTRA_INDEX_URL="https://download.pytorch.org/whl/test/${CU_VERSION}"
6+
# Install all the dependencies required for Torch-TensorRT
7+
${CONDA_RUN} pip install --pre -r ${PWD}/tests/py/requirements.txt --use-deprecated=legacy-resolver --extra-index-url=${EXTRA_INDEX_URL}
38

4-
# Install test index version of Torch and Torchvision
5-
${CONDA_RUN} ${PIP_INSTALL_TORCH} torchvision
6-
${CONDA_RUN} pip install pyyaml mpmath==1.3.0
7-
8-
# Install TRT 10 from PyPi
9-
${CONDA_RUN} pip install tensorrt==10.0.0b6 tensorrt-${CU_VERSION::4}-bindings==10.0.0b6 tensorrt-${CU_VERSION::4}-libs==10.0.0b6 --extra-index-url https://pypi.nvidia.com
10-
11-
# Install pre-built Torch-TRT
9+
# Install Torch-TensorRT via pre-built wheels. On Windows, the location of wheels is not fixed.
1210
${CONDA_RUN} pip install ${RUNNER_ARTIFACT_DIR}/torch_tensorrt*.whl
1311

14-
echo -e "Running test script";
12+
echo -e "Running test script";

‎.github/scripts/install-torch-tensorrt.sh

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
set -eou pipefail
33
# Source conda so it's available to the script environment
44
source ${BUILD_ENV_FILE}
5-
${CONDA_RUN} ${PIP_INSTALL_TORCH} torchvision
6-
${CONDA_RUN} python -m pip install pyyaml mpmath==1.3.0
7-
export TRT_VERSION=$(${CONDA_RUN} python -c "import versions; versions.tensorrt_version()")
5+
export EXTRA_INDEX_URL="https://download.pytorch.org/whl/test/${CU_VERSION}"
6+
# Install all the dependencies required for Torch-TensorRT
7+
${CONDA_RUN} pip install --pre -r ${PWD}/tests/py/requirements.txt --use-deprecated=legacy-resolver --extra-index-url=${EXTRA_INDEX_URL}
88

9-
# Install Torch-TensorRT
10-
${CONDA_RUN} python -m pip install /opt/torch-tensorrt-builds/torch_tensorrt*+${CU_VERSION}*.whl tensorrt~=${TRT_VERSION} --extra-index-url=https://pypi.ngc.nvidia.com
9+
# Install Torch-TensorRT via pre-built wheels. On Windows, the location of wheels is not fixed.
10+
${CONDA_RUN} pip install /opt/torch-tensorrt-builds/torch_tensorrt*.whl
1111

12-
echo -e "Running test script";
12+
echo -e "Running test script";

‎.github/workflows/build-test-linux.yml

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ jobs:
6666
package-name: torch_tensorrt
6767
pre-script: packaging/pre_build_script.sh
6868
post-script: packaging/post_build_script.sh
69+
smoke-test-script: packaging/smoke_test_script.sh
6970
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
7071
with:
7172
job-name: tests-py-torchscript-fe
@@ -80,13 +81,10 @@ jobs:
8081
export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH
8182
pushd .
8283
cd tests/modules
83-
# Don't use requirements.txt here as it contains tensorrt and torch which should have been installed by now.
84-
${CONDA_RUN} python -m pip install numpy packaging pyyaml transformers==4.39.3 timm==0.9.16 pybind11==2.6.2
8584
${CONDA_RUN} python hub.py
8685
popd
8786
pushd .
8887
cd tests/py/ts
89-
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
9088
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_api_test_results.xml api/
9189
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_models_test_results.xml models/
9290
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_integrations_test_results.xml integrations/
@@ -103,6 +101,7 @@ jobs:
103101
package-name: torch_tensorrt
104102
pre-script: packaging/pre_build_script.sh
105103
post-script: packaging/post_build_script.sh
104+
smoke-test-script: packaging/smoke_test_script.sh
106105
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
107106
with:
108107
job-name: tests-py-dynamo-converters
@@ -116,7 +115,6 @@ jobs:
116115
export USE_HOST_DEPS=1
117116
pushd .
118117
cd tests/py/dynamo
119-
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
120118
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 10 conversion/
121119
popd
122120
@@ -131,6 +129,7 @@ jobs:
131129
package-name: torch_tensorrt
132130
pre-script: packaging/pre_build_script.sh
133131
post-script: packaging/post_build_script.sh
132+
smoke-test-script: packaging/smoke_test_script.sh
134133
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
135134
with:
136135
job-name: tests-py-dynamo-fe
@@ -144,7 +143,6 @@ jobs:
144143
export USE_HOST_DEPS=1
145144
pushd .
146145
cd tests/py/dynamo
147-
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
148146
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
149147
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
150148
popd
@@ -160,6 +158,7 @@ jobs:
160158
package-name: torch_tensorrt
161159
pre-script: packaging/pre_build_script.sh
162160
post-script: packaging/post_build_script.sh
161+
smoke-test-script: packaging/smoke_test_script.sh
163162
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
164163
with:
165164
job-name: tests-py-dynamo-serde
@@ -173,7 +172,6 @@ jobs:
173172
export USE_HOST_DEPS=1
174173
pushd .
175174
cd tests/py/dynamo
176-
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
177175
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
178176
popd
179177
@@ -188,6 +186,7 @@ jobs:
188186
package-name: torch_tensorrt
189187
pre-script: packaging/pre_build_script.sh
190188
post-script: packaging/post_build_script.sh
189+
smoke-test-script: packaging/smoke_test_script.sh
191190
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
192191
with:
193192
job-name: tests-py-torch-compile-be
@@ -201,7 +200,6 @@ jobs:
201200
export USE_HOST_DEPS=1
202201
pushd .
203202
cd tests/py/dynamo
204-
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
205203
${CONDA_RUN} python -m pytest -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/
206204
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_comple_be_e2e_test_results.xml --ir torch_compile models/test_models.py
207205
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_dyn_models_export.xml --ir torch_compile models/test_dyn_models.py
@@ -218,6 +216,7 @@ jobs:
218216
package-name: torch_tensorrt
219217
pre-script: packaging/pre_build_script.sh
220218
post-script: packaging/post_build_script.sh
219+
smoke-test-script: packaging/smoke_test_script.sh
221220
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
222221
with:
223222
job-name: tests-py-dynamo-core
@@ -231,7 +230,6 @@ jobs:
231230
export USE_HOST_DEPS=1
232231
pushd .
233232
cd tests/py/dynamo
234-
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
235233
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_test_results.xml runtime/
236234
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_partitioning_test_results.xml partitioning/
237235
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_lowering_test_results.xml lowering/
@@ -247,7 +245,9 @@ jobs:
247245
- repository: pytorch/tensorrt
248246
package-name: torch_tensorrt
249247
pre-script: packaging/pre_build_script.sh
250-
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
248+
post-script: packaging/post_build_script.sh
249+
smoke-test-script: packaging/smoke_test_script.sh
250+
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
251251
with:
252252
job-name: tests-py-core
253253
repository: "pytorch/tensorrt"
@@ -260,6 +260,5 @@ jobs:
260260
export USE_HOST_DEPS=1
261261
pushd .
262262
cd tests/py/core
263-
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
264263
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_core_test_results.xml .
265264
popd

‎.github/workflows/build-test-windows.yml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ jobs:
7272
export USE_HOST_DEPS=1
7373
pushd .
7474
cd tests/py/dynamo
75-
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
7675
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 10 conversion/
7776
popd
7877
@@ -98,7 +97,6 @@ jobs:
9897
export USE_HOST_DEPS=1
9998
pushd .
10099
cd tests/py/dynamo
101-
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
102100
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
103101
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
104102
popd
@@ -125,7 +123,6 @@ jobs:
125123
export USE_HOST_DEPS=1
126124
pushd .
127125
cd tests/py/dynamo
128-
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
129126
${CONDA_RUN} python -m pytest -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/
130127
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_comple_be_e2e_test_results.xml --ir torch_compile models/test_models.py
131128
popd
@@ -152,7 +149,6 @@ jobs:
152149
export USE_HOST_DEPS=1
153150
pushd .
154151
cd tests/py/dynamo
155-
${CONDA_RUN} python -m pip install --pre -r ../requirements.txt
156152
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_test_results.xml runtime/
157153
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_partitioning_test_results.xml partitioning/
158154
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_lowering_test_results.xml lowering/

‎dev_dep_versions.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,7 @@ __version__: "2.3.0"
22
__cuda_version__: "12.1"
33
__cudnn_version__: "8.9"
44
__tensorrt_version__: "10.0.1"
5+
__torch_version__: "2.3.0"
6+
# torchvision version here is not a direct dependency but the one used during testing
7+
__torchvision_version__: "0.18.0"
8+
__index_url__: "https://download.pytorch.org/whl/test/"

‎docsrc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ Tutorials
111111
tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
112112
tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
113113
tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
114+
tutorials/_rendered_examples/dynamo/vgg16_fp8_ptq
114115

115116
Python API Documentation
116117
------------------------

‎examples/dynamo/README.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ a number of ways you can leverage this backend to accelerate inference.
1010
* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
1111
* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
1212
* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
13+
* :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 post-training quantization (PTQ) using ``torch.export`` and ``torch_tensorrt.dynamo.compile``

‎examples/dynamo/vgg16_fp8_ptq.py

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
"""
2+
.. _vgg16_fp8_ptq:
3+
4+
Torch Compile VGG16 with FP8 and PTQ
5+
======================================================
6+
7+
This script is intended as a sample of the Torch-TensorRT Dynamo workflow (`torch.export` followed by `torch_tensorrt.dynamo.compile`) on a VGG16 model with FP8 post-training quantization (PTQ).
8+
"""
9+
10+
# %%
11+
# Imports and Model Definition
12+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
13+
14+
import argparse
15+
16+
import modelopt.torch.quantization as mtq
17+
import torch
18+
import torch.nn as nn
19+
import torch.nn.functional as F
20+
import torch_tensorrt as torchtrt
21+
import torchvision.datasets as datasets
22+
import torchvision.transforms as transforms
23+
from modelopt.torch.quantization.utils import export_torch_mode
24+
25+
26+
class VGG(nn.Module):
    """VGG-style CNN: a conv feature stack, global average pooling and an MLP head.

    Args:
        layer_spec: Iterable describing the feature stack in order. An integer
            entry adds ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` with
            that many output channels; the string ``"pool"`` adds a
            ``MaxPool2d(kernel_size=2, stride=2)``.
        num_classes: Number of output logits produced by the classifier.
        init_weights: When True, re-initialize all Conv2d, BatchNorm2d and
            Linear parameters via ``_initialize_weights``.

    Note:
        The classifier's first layer is ``Linear(512, 4096)`` and the features
        are pooled to 1x1, so the last convolution in ``layer_spec`` must
        produce 512 channels.
    """

    def __init__(self, layer_spec, num_classes=1000, init_weights=False):
        super().__init__()

        layers = []
        in_channels = 3  # the network consumes 3-channel images
        for spec in layer_spec:
            if spec == "pool":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [
                    nn.Conv2d(in_channels, spec, kernel_size=3, padding=1),
                    nn.BatchNorm2d(spec),
                    nn.ReLU(),
                ]
                # The next convolution consumes what this one produced.
                in_channels = spec

        self.features = nn.Sequential(*layers)
        # Collapse the spatial dimensions to 1x1 so the classifier input size
        # does not depend on the input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 1 * 1, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        """Initialize conv (Kaiming normal), batch-norm (1/0) and linear (N(0, 0.01)) layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        x = self.features(x)
        x = self.avgpool(x)
        # Keep the batch dimension, flatten everything else for the MLP head.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
76+
77+
78+
def vgg16(num_classes=1000, init_weights=False):
    """Build a ``VGG`` model using the 16-layer configuration.

    The feature stack has five stages of 3x3 convolutions (2, 2, 3, 3 and 3
    convolutions with 64, 128, 256, 512 and 512 channels), each stage
    terminated by a ``"pool"`` entry.
    """
    # (channels, number of convolutions) for each stage
    stages = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))

    layer_spec = []
    for width, depth in stages:
        layer_spec.extend([width] * depth)
        layer_spec.append("pool")

    return VGG(layer_spec, num_classes, init_weights)
100+
101+
102+
PARSER = argparse.ArgumentParser(
    description="Load pre-trained VGG model and then tune with FP8 and PTQ"
)
PARSER.add_argument(
    "--ckpt", type=str, required=True, help="Path to the pre-trained checkpoint"
)
PARSER.add_argument(
    "--batch-size",
    default=128,
    type=int,
    help="Batch size for tuning the model with PTQ and FP8",
)

args = PARSER.parse_args()

# CIFAR10 has 10 classes; the weights come from the checkpoint, so no init is needed.
model = vgg16(num_classes=10, init_weights=False)
model = model.cuda()

# %%
# Load the pre-trained model weights
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Deserialize onto the CPU so loading does not depend on the GPU layout of the
# machine that saved the checkpoint; ``load_state_dict`` copies the values into
# the CUDA parameters of ``model``.
ckpt = torch.load(args.ckpt, map_location="cpu")
weights = ckpt["model_state_dict"]

# Keys may carry a leading ``module.`` (e.g. when the model was saved while
# wrapped in ``nn.DataParallel``). Decide per key whether to strip it, rather
# than guessing from the number of GPUs on the machine running this script.
weights = {
    (k[len("module.") :] if k.startswith("module.") else k): v
    for k, v in weights.items()
}

model.load_state_dict(weights)
# Don't forget to set the model to evaluation mode!
model.eval()
139+
140+
# %%
141+
# Load training dataset and define loss function for PTQ
142+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
143+
144+
training_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)
training_dataloader = torch.utils.data.DataLoader(
    training_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2
)

# Grab one batch of images; it is used later as the example input for export.
data = iter(training_dataloader)
images, _ = next(data)

crit = nn.CrossEntropyLoss()

# %%
# Define Calibration Loop for quantization
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


def calibrate_loop(model):
    """Run ``model`` over the training set and print its loss and accuracy.

    Passed to ``mtq.quantize`` as the ``forward_loop`` used for calibration.
    """
    # calibrate over the training dataset
    total = 0
    correct = 0
    loss = 0.0
    for data, labels in training_dataloader:
        data, labels = data.cuda(), labels.cuda(non_blocking=True)
        out = model(data)
        # ``crit`` returns the mean loss of the batch. Weight it by the batch
        # size so that ``loss / total`` below is the mean per-sample loss, and
        # use ``.item()`` so the running sum is a plain float that does not
        # keep every batch's autograd graph alive.
        loss += crit(out, labels).item() * labels.size(0)
        preds = torch.max(out, 1)[1]
        total += labels.size(0)
        correct += (preds == labels).sum().item()

    print("PTQ Loss: {:.5f} Acc: {:.2f}%".format(loss / total, 100 * correct / total))
185+
186+
187+
# %%
188+
# Tune the pre-trained model with FP8 and PTQ
189+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
190+
191+
quant_cfg = mtq.FP8_DEFAULT_CFG
# PTQ with in-place replacement to quantized modules
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
# model has FP8 qdq nodes at this point

# %%
# Inference
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Load the testing dataset
testing_dataset = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)

# The model is compiled below from a single concrete example batch, so every
# batch fed to the compiled module must have that same batch size. Drop the
# last, possibly smaller, batch of the test set.
testing_dataloader = torch.utils.data.DataLoader(
    testing_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=2,
    drop_last=True,
)

with torch.no_grad():
    with export_torch_mode():
        # Compile the model with Torch-TensorRT Dynamo backend
        input_tensor = images.cuda()
        exp_program = torch.export.export(model, (input_tensor,))
        trt_model = torchtrt.dynamo.compile(
            exp_program,
            inputs=[input_tensor],
            enabled_precisions={torch.float8_e4m3fn},
            min_block_size=1,
            debug=False,
        )

        # Inference compiled Torch-TensorRT model over the testing dataset
        total = 0
        correct = 0
        loss = 0.0
        class_probs = []
        class_preds = []
        for data, labels in testing_dataloader:
            data, labels = data.cuda(), labels.cuda(non_blocking=True)
            # Evaluate the compiled TensorRT module, not the PyTorch model.
            out = trt_model(data)
            # ``crit`` returns the batch mean; weight by batch size so that
            # ``loss / total`` is the mean per-sample loss.
            loss += crit(out, labels).item() * labels.size(0)
            preds = torch.max(out, 1)[1]
            # Softmax over the class dimension for every sample in the batch.
            class_probs.append(F.softmax(out, dim=1))
            class_preds.append(preds)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

        test_probs = torch.cat(class_probs)
        test_preds = torch.cat(class_preds)
        test_loss = loss / total
        test_acc = correct / total
        print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))

‎examples/int8/training/vgg16/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@ nvidia-pyindex
44
--extra-index-url https://pypi.nvidia.com
55
pytorch-quantization
66
tqdm
7+
nvidia-modelopt
8+
--extra-index-url https://pypi.nvidia.com

‎packaging/pre_build_script.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# Install dependencies
44
python3 -m pip install pyyaml
55
yum install -y ninja-build gettext
6-
TRT_VERSION=$(python3 -c "import versions; versions.tensorrt_version()")
76

87
wget https://github.com/bazelbuild/bazelisk/releases/download/v1.17.0/bazelisk-linux-amd64 \
98
&& mv bazelisk-linux-amd64 /usr/bin/bazel \

‎packaging/pre_build_script_windows.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
python -m pip install -U numpy packaging pyyaml setuptools wheel
22

33
# Install TRT 10 from PyPi
4-
python -m pip install tensorrt==10.0.0b6 tensorrt-${CU_VERSION::4}-bindings==10.0.0b6 tensorrt-${CU_VERSION::4}-libs==10.0.0b6 --extra-index-url https://pypi.nvidia.com
4+
python -m pip install tensorrt==10.0.1 --extra-index-url https://pypi.nvidia.com
55

66
choco install bazelisk -y
77

‎py/torch_tensorrt/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from torch_tensorrt._version import ( # noqa: F401
88
__cuda_version__,
9-
__cudnn_version__,
109
__tensorrt_version__,
1110
__version__,
1211
)
@@ -40,11 +39,9 @@ def _find_lib(name: str, paths: List[str]) -> str:
4039
import tensorrt # noqa: F401
4140
except ImportError:
4241
cuda_version = _parse_semver(__cuda_version__)
43-
cudnn_version = _parse_semver(__cudnn_version__)
4442
tensorrt_version = _parse_semver(__tensorrt_version__)
4543

4644
CUDA_MAJOR = cuda_version["major"]
47-
CUDNN_MAJOR = cudnn_version["major"]
4845
TENSORRT_MAJOR = tensorrt_version["major"]
4946

5047
if sys.platform.startswith("win"):

‎py/torch_tensorrt/_enums.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ class dtype(Enum):
2323
f32 = auto()
2424
f64 = auto()
2525
b = auto()
26+
27+
f8 = auto()
2628
bf16 = auto()
27-
# TODO: Enable FP8
28-
# f8 = auto()
2929

3030
uint8 = u8
3131
int8 = i8
@@ -35,6 +35,9 @@ class dtype(Enum):
3535
long = i64
3636
int64 = i64
3737

38+
float8 = f8
39+
fp8 = f8
40+
3841
half = f16
3942
fp16 = f16
4043
float16 = f16
@@ -47,10 +50,6 @@ class dtype(Enum):
4750
fp64 = f64
4851
float64 = f64
4952

50-
# TODO: Enable when FP8 is enabled
51-
# float8 = f8
52-
# fp8 = f8
53-
5453
bfloat16 = bf16
5554

5655
@staticmethod
@@ -78,6 +77,8 @@ def _from(
7877
return dtype.i64
7978
elif t == torch.int32:
8079
return dtype.i32
80+
elif t == torch.float8_e4m3fn:
81+
return dtype.f8
8182
elif t == torch.half:
8283
return dtype.f16
8384
elif t == torch.float:
@@ -102,6 +103,8 @@ def _from(
102103
return dtype.u8
103104
elif t == trt.DataType.INT8:
104105
return dtype.i8
106+
elif t == trt.DataType.FP8:
107+
return dtype.f8
105108
elif t == trt.DataType.INT32:
106109
return dtype.i32
107110
elif t == trt.DataType.INT64:
@@ -209,6 +212,8 @@ def to(
209212
return torch.int
210213
elif self == dtype.i64:
211214
return torch.long
215+
elif self == dtype.f8:
216+
return torch.float8_e4m3fn
212217
elif self == dtype.f16:
213218
return torch.half
214219
elif self == dtype.f32:
@@ -234,6 +239,8 @@ def to(
234239
return trt.DataType.INT8
235240
elif self == dtype.i32:
236241
return trt.DataType.INT32
242+
elif self == dtype.f8:
243+
return trt.DataType.FP8
237244
elif self == dtype.i64:
238245
return trt.DataType.INT64
239246
elif self == dtype.f16:

‎py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@
2828
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
2929
DYNAMO_CONVERTERS as CONVERTERS,
3030
)
31-
from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
31+
from torch_tensorrt.dynamo.lowering import (
32+
get_decompositions,
33+
post_lowering,
34+
pre_export_lowering,
35+
)
3236
from torch_tensorrt.dynamo.utils import (
3337
get_torch_inputs,
3438
parse_complex_tensor_structs,
@@ -167,22 +171,22 @@ def compile(
167171

168172
# Prepare torch_trt inputs
169173
inputs = prepare_inputs(inputs)
174+
torch_inputs = get_torch_inputs(inputs, device)
170175
device = to_torch_tensorrt_device(device)
171176
enabled_precisions = {dtype._from(p) for p in enabled_precisions}
172177

173178
if not isinstance(exported_program, ExportedProgram):
174179
raise AssertionError(
175180
f"Input graph should be an ExportedProgram but got type {type(exported_program)}"
176181
)
182+
exported_program = pre_export_lowering(exported_program, torch_inputs)
177183
exported_program = exported_program.run_decompositions(
178184
get_decompositions(enable_experimental_decompositions)
179185
)
180186
gm = exported_program.module()
181187
logger.debug("Input graph: " + str(gm.graph))
182188
# Apply lowering on the graph module
183-
torch_inputs = get_torch_inputs(inputs, device)
184-
gm = apply_lowering_passes(gm, torch_inputs)
185-
189+
gm = post_lowering(gm, torch_inputs)
186190
logger.debug("Lowered Input graph: " + str(gm.graph))
187191

188192
compilation_options = {
@@ -553,7 +557,7 @@ def convert_module_to_trt_engine(
553557
# Prepare torch_trt inputs
554558
input_list = prepare_inputs(input_list)
555559
device = to_torch_tensorrt_device(device)
556-
560+
torch_inputs = get_torch_inputs(input_list, device)
557561
enabled_precisions = {dtype._from(e) for e in enabled_precisions}
558562

559563
compilation_options = {
@@ -583,6 +587,7 @@ def convert_module_to_trt_engine(
583587
"dla_global_dram_size": dla_global_dram_size,
584588
}
585589

590+
exported_program = pre_export_lowering(exported_program, torch_inputs)
586591
# Decompose the exported program
587592
exported_program = exported_program.run_decompositions(
588593
get_decompositions(enable_experimental_decompositions)
@@ -591,8 +596,7 @@ def convert_module_to_trt_engine(
591596
logger.debug("Input graph: " + str(gm.graph))
592597

593598
# Apply lowering on the graph module
594-
torch_inputs = get_torch_inputs(input_list, device)
595-
gm = apply_lowering_passes(gm, torch_inputs)
599+
gm = post_lowering(gm, torch_inputs)
596600
logger.debug("Lowered Input graph: " + str(gm.graph))
597601

598602
settings = CompilationSettings(**compilation_options)

‎py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
REQUIRE_FULL_COMPILATION = False
2828
DRYRUN = False
2929
HARDWARE_COMPATIBLE = False
30-
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8, dtype.bf16}
30+
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
3131

3232

3333
def default_device() -> Device:

‎py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
from torch_tensorrt.dynamo import CompilationSettings
1212
from torch_tensorrt.dynamo._compiler import compile_module
1313
from torch_tensorrt.dynamo.lowering import (
14-
apply_lowering_passes,
1514
get_decompositions,
15+
post_lowering,
16+
remove_detach,
1617
remove_sym_nodes,
1718
repair_input_aliasing,
1819
)
@@ -82,6 +83,9 @@ def _pretraced_backend(
8283
input for input in sample_inputs if isinstance(input, torch.Tensor)
8384
]
8485

86+
# Remove detach nodes
87+
remove_detach(gm, torch_inputs)
88+
8589
# Invoke AOTAutograd to translate operators to aten
8690
gm = aot_export_joint_simple(
8791
gm,
@@ -94,7 +98,7 @@ def _pretraced_backend(
9498

9599
logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
96100

97-
gm = apply_lowering_passes(gm, torch_inputs)
101+
gm = post_lowering(gm, sample_inputs)
98102

99103
logger.debug("Lowered Input graph:\n " + str(gm.graph))
100104

‎py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,6 @@ def __init__(
106106
[dtype._from(o) for o in output_dtypes] if output_dtypes else None
107107
)
108108

109-
_LOGGER.debug(f"Graph to be compiled to TensorRT: {self.module.graph}")
110-
111109
def validate_conversion(self) -> Set[str]:
112110
missing_converters: Set[str] = set()
113111

@@ -243,6 +241,8 @@ def _populate_trt_builder_config(
243241
if dtype.int8 in self.compilation_settings.enabled_precisions:
244242
builder_config.set_flag(trt.BuilderFlag.INT8)
245243

244+
if dtype.fp8 in self.compilation_settings.enabled_precisions:
245+
builder_config.set_flag(trt.BuilderFlag.FP8)
246246
if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
247247
builder_config.set_flag(trt.BuilderFlag.BF16)
248248

‎py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,34 @@ def aten_ops_neg(
579579
)
580580

581581

582+
try:
    # Imported only for its side effect: the op probed below is expected to
    # become available once modelopt has been imported (see the warning text).
    import modelopt.torch.quantization as mtq  # noqa: F401

    # Probe for the op with a plain attribute access instead of ``assert``:
    # assert statements are removed under ``python -O``, which would skip the
    # check and let the decorator below fail at import time instead.
    torch.ops.trt.quantize_fp8.default  # noqa: B018
except Exception as e:
    # Quantization support is optional; keep the module importable without it.
    _LOGGER.warning(
        "Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models"
    )
    _LOGGER.debug(f"Quantization op import failed with: {e}")
else:

    @dynamo_tensorrt_converter(torch.ops.trt.quantize_fp8.default)
    def aten_ops_quantize_fp8(
        ctx: ConversionContext,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        name: str,
    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
        """Convert ``trt.quantize_fp8`` by delegating to ``impl.quantize.quantize_fp8``.

        ``args[0]`` is forwarded as the input tensor and ``args[1]`` as the scale.
        """
        return impl.quantize.quantize_fp8(
            ctx,
            target,
            SourceIR.ATEN,
            name,
            args[0],
            args[1],
        )
608+
609+
582610
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim)
583611
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims)
584612
def aten_ops_squeeze(

‎py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
pad,
1919
permutation,
2020
pool,
21+
quantize,
2122
reduce,
2223
select,
2324
shape,
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from typing import Optional
2+
3+
import numpy as np
4+
import tensorrt as trt
5+
from torch.fx.node import Target
6+
from torch_tensorrt.dynamo._SourceIR import SourceIR
7+
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
8+
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
9+
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
10+
from torch_tensorrt.fx.types import TRTTensor
11+
12+
13+
def quantize_fp8(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_tensor: TRTTensor,
    scale: np.ndarray,
) -> TRTTensor:
    """
    Adds a quantize/dequantize (QDQ) layer pair that quantizes ``input_tensor``
    to FP8 using ``scale`` and dequantizes it back.

    Args:
        ctx: Conversion context whose ``net`` receives the new layers.
        target: FX target being converted (used for layer naming).
        source_ir: IR the node originated from (used for layer naming).
        name: Base name for the created layers and constants.
        input_tensor: Tensor to quantize. A value that is not already a
            ``TRTTensor`` is first added to the network via ``get_trt_tensor``.
        scale: Quantization scale, shared by the Q and DQ layers.

    Returns:
        The output tensor of the dequantize layer.

    Raises:
        ValueError: If ``input_tensor`` is a ``TRTTensor`` whose dtype is
            neither float32 nor float16.
    """
    if isinstance(input_tensor, TRTTensor):
        if input_tensor.dtype not in (trt.float32, trt.float16):
            raise ValueError(
                f"quantize_fp8 converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16"
            )
    else:
        # add_quantize needs a network tensor; materialize other values the
        # same way the scale is materialized below.
        input_tensor = get_trt_tensor(ctx, input_tensor, name + "_input")

    scale = get_trt_tensor(ctx, scale, name + "_scale")
    # Add Q node
    quantize_layer = ctx.net.add_quantize(input_tensor, scale)
    quantize_layer.set_output_type(0, trt.DataType.FP8)
    set_layer_name(quantize_layer, target, name + "_quantize", source_ir)
    q_output = quantize_layer.get_output(0)
    # Add DQ node
    dequantize_layer = ctx.net.add_dequantize(q_output, scale)
    set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
    # Set DQ layer precision to FP8
    dequantize_layer.precision = trt.DataType.FP8
    dq_output = dequantize_layer.get_output(0)

    return dq_output

‎py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55
from ._decompositions import get_decompositions # noqa: F401
66
from ._remove_sym_nodes import remove_sym_nodes
77
from ._repair_input_aliasing import repair_input_aliasing
8-
from .passes import apply_lowering_passes
8+
from .passes import post_lowering, pre_export_lowering
9+
from .passes.remove_detach import remove_detach

‎py/torch_tensorrt/dynamo/lowering/_remove_sym_nodes.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,18 @@ def remove_sym_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
1010
dynamic=True behavior
1111
"""
1212
# Extract SymInt placeholder Tensors
13-
placeholders = [
13+
placeholder_sym_ints = [
1414
node
1515
for node in gm.graph.nodes
1616
if (
1717
node.op == "placeholder"
1818
and isinstance(node.type, type)
1919
and issubclass(node.type, torch.SymInt)
20+
and not node.users
2021
)
2122
]
2223

23-
for node in placeholders:
24+
for node in placeholder_sym_ints:
2425
gm.graph.erase_node(node)
2526

2627
gm.graph.lint()

‎py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88
from .lower_linear import lower_linear
99
from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
1010
from .pass_manager import DynamoPassManager
11+
from .remove_detach import remove_detach
1112
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
1213
from .repair_input_as_output import repair_input_as_output
1314
from .replace_max_pool_with_indices import replace_max_pool_with_indices
1415
from .view_to_reshape import view_to_reshape
1516

16-
ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
17+
ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
1718
[
1819
remove_input_alias_fixing_clones,
1920
constant_fold,
@@ -26,6 +27,12 @@
2627
]
2728
)
2829

30+
ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
31+
[
32+
remove_detach,
33+
]
34+
)
35+
2936
logger = logging.getLogger(__name__)
3037

3138

@@ -48,9 +55,9 @@ def _aten_lowering_pass(
4855
def add_lowering_pass(
4956
lowering_pass: LoweringPassSignature,
5057
) -> LoweringPassSignature:
51-
ATEN_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
58+
ATEN_POST_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
5259
logger.debug(
53-
f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
60+
f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}"
5461
)
5562
return lowering_pass
5663

@@ -72,23 +79,35 @@ def add_lowering_pass(
7279

7380
def _remove_lowering_pass(*, index: int) -> None:
    """Drop the post-lowering pass registered at position ``index``."""
    ATEN_POST_LOWERING_PASSES.remove_pass_with_index(index)
    # The message is built after the removal so it shows the updated pass list.
    message = f"Removed lowering pass at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}"
    logger.debug(message)
8087

8188

82-
def apply_lowering_passes(
89+
def post_lowering(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    """Apply the post-lowering ATen passes to a graph module.

    Intended to run after torch.export / torch.compile and their
    decompositions.

    Args:
        gm: Graph module to transform.
        sample_inputs: Sample inputs forwarded to every pass.

    Returns:
        The graph module returned by ``ATEN_POST_LOWERING_PASSES``.
    """
    # Use the module-level logger (not the root ``logging`` module) so this
    # message follows the same logger configuration as the rest of this file.
    logger.debug(
        f"Invoking DynamoPassManager and applying lowering passes: {ATEN_POST_LOWERING_PASSES}"
    )
    return ATEN_POST_LOWERING_PASSES(gm, sample_inputs)
97+
98+
99+
def pre_export_lowering(
    ep: torch.export.ExportedProgram, sample_inputs: Sequence[torch.Tensor]
) -> torch.export.ExportedProgram:
    """Apply the pre-lowering ATen passes to an exported program's graph module.

    Intended to run on the result of torch.export before decompositions are
    applied.

    Args:
        ep: Exported program whose ``graph_module`` is handed to the passes.
        sample_inputs: Sample inputs forwarded to every pass.

    Returns:
        The same ``ExportedProgram`` object that was passed in.
    """
    # Use the module-level logger (not the root ``logging`` module) so this
    # message follows the same logger configuration as the rest of this file.
    logger.debug(
        f"Invoking DynamoPassManager and applying lowering passes: {ATEN_PRE_LOWERING_PASSES}"
    )
    # The pass manager's return value is not used: ``ep`` is returned as-is, so
    # only in-place edits to ``ep.graph_module`` take effect.
    # NOTE(review): a pass that returns a new GraphModule would be silently
    # dropped here - confirm every pre-lowering pass mutates in place.
    ATEN_PRE_LOWERING_PASSES(ep.graph_module, sample_inputs)
    return ep
90109

91110

92111
def dump_lowering_passes() -> str:
    """Return the string form of the registered post-lowering pass list."""
    passlist_repr = str(ATEN_POST_LOWERING_PASSES)
    return passlist_repr
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import logging
2+
from typing import Sequence
3+
4+
import torch
5+
6+
logger = logging.getLogger(__name__)


def remove_detach(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    """Remove detach ops from the graph.

    Every use of a detach node is rewired to the detach node's single input,
    and the detach node is erased. If anything was removed, the graph is
    linted and the module recompiled so its generated code matches the graph.

    Args:
        gm: Graph module to modify in place.
        sample_inputs: Unused; present to match the lowering-pass signature.

    Returns:
        The same ``gm`` object, with detach nodes removed.
    """
    count = 0
    # Iterate over a snapshot so erasing nodes cannot disturb the traversal.
    for node in list(gm.graph.nodes):
        # Only call nodes can be detach ops. Without this check a placeholder
        # or get_attr node that happens to be named "detach" would match the
        # string comparison below and fail on the empty input list.
        if node.op not in ("call_function", "call_method"):
            continue
        # node.target = "detach" in torch.compile workflow
        if node.target == torch.ops.aten.detach.default or node.target == "detach":
            # Detach node has only one input
            node_input = node.all_input_nodes[0]
            node.replace_all_uses_with(node_input)
            gm.graph.erase_node(node)
            count += 1

    if count:
        # Regenerate forward() so it no longer calls the erased detach ops.
        gm.graph.lint()
        gm.recompile()

    logger.debug(f"Removed {count} detach nodes:\n{gm.graph}")

    return gm

‎tests/py/dynamo/conversion/harness.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717
# Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry
1818
from torch_tensorrt.dynamo.conversion import TRTInterpreter
1919
from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes
20-
from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
20+
from torch_tensorrt.dynamo.lowering import (
21+
get_decompositions,
22+
post_lowering,
23+
pre_export_lowering,
24+
)
2125
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
2226
from torch_tensorrt.dynamo.utils import get_torch_inputs
2327

@@ -210,14 +214,16 @@ def generate_graph(
210214
torch_inputs = get_torch_inputs(original_inputs, _defaults.DEVICE)
211215
if use_dynamo_tracer:
212216
exported_program = torch_tensorrt.dynamo.trace(mod, tuple(original_inputs))
217+
exported_program = pre_export_lowering(exported_program, torch_inputs)
213218
exported_program = exported_program.run_decompositions(
214219
get_decompositions(False)
215220
)
216221
fx_module = exported_program.module()
217222
else:
218223
fx_module = torch.fx.symbolic_trace(mod)
224+
219225
if enable_passes:
220-
fx_module = apply_lowering_passes(fx_module, torch_inputs)
226+
fx_module = post_lowering(fx_module, original_inputs)
221227

222228
if propagate_shapes:
223229
# TODO: This is currently being used to test embedding_bag_aten due to https://github.com/pytorch/TensorRT/issues/2843

‎tests/py/dynamo/lowering/test_aten_lowering_passes.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
import unittest
33

44
import torch
5-
from torch.testing._internal.common_utils import TestCase, run_tests
6-
75
import torch_tensorrt
6+
from torch.testing._internal.common_utils import TestCase, run_tests
87

98
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
109

@@ -397,6 +396,9 @@ def forward(self, q, k, v):
397396

398397

399398
class TestLowerLinear(TestCase):
399+
@unittest.skip(
400+
"This test has threshold failures. This is tracked at https://github.com/pytorch/TensorRT/issues/2715",
401+
)
400402
def test_lower_linear(self):
401403
class Linear(torch.nn.Module):
402404
def forward(self, input, weight, bias):
@@ -464,6 +466,9 @@ def forward(self, input, weight, bias):
464466
)
465467
torch._dynamo.reset()
466468

469+
@unittest.skip(
470+
"This test has threshold failures. This is tracked at https://github.com/pytorch/TensorRT/issues/2715",
471+
)
467472
def test_lower_linear_batch(self):
468473
class Linear(torch.nn.Module):
469474
def forward(self, input, weight, bias):

‎tests/py/dynamo/models/test_models_export.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,50 @@ def test_resnet18_half(ir):
182182

183183
# Clean up model env
184184
torch._dynamo.reset()
185+
186+
187+
@unittest.skipIf(
    # Check CUDA availability first: querying device properties without a GPU
    # raises at import time and breaks test collection instead of skipping.
    not torch.cuda.is_available()
    or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
    "FP8 compilation in Torch-TRT is not supported on cards older than Hopper",
)
@pytest.mark.unit
def test_base_fp8(ir):
    """Compile an FP8-quantized two-layer MLP and compare it against PyTorch."""

    class SimpleNetwork(torch.nn.Module):
        def __init__(self):
            super(SimpleNetwork, self).__init__()
            self.linear1 = torch.nn.Linear(in_features=10, out_features=5)
            self.linear2 = torch.nn.Linear(in_features=5, out_features=1)

        def forward(self, x):
            x = self.linear1(x)
            x = torch.nn.ReLU()(x)
            x = self.linear2(x)
            return x

    import modelopt.torch.quantization as mtq
    from modelopt.torch.quantization.utils import export_torch_mode

    # Defined before calibrate_loop so the closure's dependency is obvious.
    input_tensor = torch.randn(1, 10).cuda()
    model = SimpleNetwork().eval().cuda()

    def calibrate_loop(model):
        """Simple calibration function for testing."""
        model(input_tensor)

    quant_cfg = mtq.FP8_DEFAULT_CFG
    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    # model has FP8 qdq nodes at this point
    output_pyt = model(input_tensor)

    with torch.no_grad():
        with export_torch_mode():
            exp_program = torch.export.export(model, (input_tensor,))
            trt_model = torchtrt.dynamo.compile(
                exp_program,
                inputs=[input_tensor],
                enabled_precisions={torch.float8_e4m3fn},
                min_block_size=1,
                debug=True,
            )
            outputs_trt = trt_model(input_tensor)
            assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)

‎tests/py/dynamo/testing_utilities.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from torch._functorch.aot_autograd import aot_export_joint_simple
99
from torch_tensorrt.dynamo import partitioning
1010
from torch_tensorrt.dynamo.lowering import (
11-
apply_lowering_passes,
1211
get_decompositions,
12+
post_lowering,
1313
repair_input_aliasing,
1414
)
1515

@@ -50,7 +50,7 @@ def fx_dynamo_testing_backend(
5050
decompositions=get_decompositions(),
5151
)
5252

53-
gm = apply_lowering_passes(gm, sample_inputs)
53+
gm = post_lowering(gm, sample_inputs)
5454

5555
trt_compiled = custom_backend(
5656
gm,

‎tests/py/requirements.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
1+
# This file exists specifically to install the correct versions of libraries during CI testing.
2+
# The index url for torch & torchvision libs is configured in install-torch-tensorrt.sh based on CUDA version
3+
# networkx library issue: https://discuss.pytorch.org/t/installing-pytorch-under-python-3-8-question-about-networkx-version/196740
14
pytest>=8.2.1
25
pytest-xdist>=3.6.1
6+
networkx==2.8.8
7+
torch==2.3.0
8+
torchvision==0.18.0
9+
--extra-index-url https://pypi.ngc.nvidia.com
10+
pyyaml
11+
tensorrt==10.0.1
312
timm>=1.0.3
413
transformers==4.39.3
514
parameterized>=0.2.0

‎versions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
__cudnn_version__ = "0.0"
1212
__tensorrt_version__ = "0.0"
1313

14-
1514
LEADING_V_PATTERN = re.compile("^v")
1615
TRAILING_RC_PATTERN = re.compile("-rc[0-9]*$")
1716
LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile("a0$")

0 commit comments

Comments
 (0)
Please sign in to comment.