Skip to content

Commit 6ba44fa

Browse files
authored
Run test_base_fp8 for compute capability 8.9 or later (#3164)
1 parent 6da143c commit 6ba44fa

File tree

1 file changed

+12 -17 lines changed

1 file changed

+12 -17 lines changed

tests/py/dynamo/models/test_models_export.py

Lines changed: 12 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111
import torchvision.models as models
1212
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
1313
from transformers import BertModel
14-
from transformers.utils.fx import symbolic_trace as transformers_trace
1514

1615
from packaging.version import Version
1716

@@ -196,16 +195,18 @@ def test_resnet18_half(ir):
196195

197196

198197
@unittest.skipIf(
199-
torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
200-
"FP8 compilation in Torch-TRT is not supported on cards older than Hopper",
198+
torch.cuda.get_device_capability() < (8, 9),
199+
"FP8 quantization requires compute capability 8.9 or later",
201200
)
202201
@unittest.skipIf(
203202
not importlib.util.find_spec("modelopt"),
204-
reason="ModelOpt is necessary to run this test",
203+
"ModelOpt is required to run this test",
205204
)
206205
@pytest.mark.unit
207206
def test_base_fp8(ir):
208-
import modelopt
207+
import modelopt.torch.quantization as mtq
208+
from modelopt.torch.quantization.utils import export_torch_mode
209+
from torch.export._trace import _export
209210

210211
class SimpleNetwork(torch.nn.Module):
211212
def __init__(self):
@@ -219,9 +220,6 @@ def forward(self, x):
219220
x = self.linear2(x)
220221
return x
221222

222-
import modelopt.torch.quantization as mtq
223-
from modelopt.torch.quantization.utils import export_torch_mode
224-
225223
def calibrate_loop(model):
226224
"""Simple calibration function for testing."""
227225
model(input_tensor)
@@ -236,7 +234,7 @@ def calibrate_loop(model):
236234

237235
with torch.no_grad():
238236
with export_torch_mode():
239-
exp_program = torch.export.export(model, (input_tensor,))
237+
exp_program = _export(model, (input_tensor,))
240238
trt_model = torchtrt.dynamo.compile(
241239
exp_program,
242240
inputs=[input_tensor],
@@ -247,7 +245,7 @@ def calibrate_loop(model):
247245
reuse_cached_engines=False,
248246
)
249247
outputs_trt = trt_model(input_tensor)
250-
assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)
248+
assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)
251249

252250

253251
@unittest.skipIf(
@@ -258,7 +256,9 @@ def calibrate_loop(model):
258256
)
259257
@pytest.mark.unit
260258
def test_base_int8(ir):
261-
import modelopt
259+
import modelopt.torch.quantization as mtq
260+
from modelopt.torch.quantization.utils import export_torch_mode
261+
from torch.export._trace import _export
262262

263263
class SimpleNetwork(torch.nn.Module):
264264
def __init__(self):
@@ -272,9 +272,6 @@ def forward(self, x):
272272
x = self.linear2(x)
273273
return x
274274

275-
import modelopt.torch.quantization as mtq
276-
from modelopt.torch.quantization.utils import export_torch_mode
277-
278275
def calibrate_loop(model):
279276
"""Simple calibration function for testing."""
280277
model(input_tensor)
@@ -289,8 +286,6 @@ def calibrate_loop(model):
289286

290287
with torch.no_grad():
291288
with export_torch_mode():
292-
from torch.export._trace import _export
293-
294289
exp_program = _export(model, (input_tensor,))
295290
trt_model = torchtrt.dynamo.compile(
296291
exp_program,
@@ -302,4 +297,4 @@ def calibrate_loop(model):
302297
reuse_cached_engines=False,
303298
)
304299
outputs_trt = trt_model(input_tensor)
305-
assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)
300+
assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)

0 commit comments

Comments
 (0)