@@ -11,7 +11,6 @@
 import torchvision.models as models
 from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
 from transformers import BertModel
-from transformers.utils.fx import symbolic_trace as transformers_trace
 
 from packaging.version import Version
 
@@ -196,16 +195,18 @@ def test_resnet18_half(ir):
 
 
 @unittest.skipIf(
-    torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
-    "FP8 compilation in Torch-TRT is not supported on cards older than Hopper",
+    torch.cuda.get_device_capability() < (8, 9),
+    "FP8 quantization requires compute capability 8.9 or later",
 )
 @unittest.skipIf(
     not importlib.util.find_spec("modelopt"),
-    reason="ModelOpt is necessary to run this test",
+    "ModelOpt is required to run this test",
 )
 @pytest.mark.unit
 def test_base_fp8(ir):
-    import modelopt
+    import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.utils import export_torch_mode
+    from torch.export._trace import _export
 
     class SimpleNetwork(torch.nn.Module):
         def __init__(self):
@@ -219,9 +220,6 @@ def forward(self, x):
             x = self.linear2(x)
             return x
 
-    import modelopt.torch.quantization as mtq
-    from modelopt.torch.quantization.utils import export_torch_mode
-
     def calibrate_loop(model):
         """Simple calibration function for testing."""
         model(input_tensor)
@@ -236,7 +234,7 @@ def calibrate_loop(model):
 
     with torch.no_grad():
         with export_torch_mode():
-            exp_program = torch.export.export(model, (input_tensor,))
+            exp_program = _export(model, (input_tensor,))
             trt_model = torchtrt.dynamo.compile(
                 exp_program,
                 inputs=[input_tensor],
@@ -247,7 +245,7 @@ def calibrate_loop(model):
                 reuse_cached_engines=False,
             )
             outputs_trt = trt_model(input_tensor)
-            assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)
+            assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)
 
 
 @unittest.skipIf(
@@ -258,7 +256,9 @@ def calibrate_loop(model):
 )
 @pytest.mark.unit
 def test_base_int8(ir):
-    import modelopt
+    import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.utils import export_torch_mode
+    from torch.export._trace import _export
 
     class SimpleNetwork(torch.nn.Module):
         def __init__(self):
@@ -272,9 +272,6 @@ def forward(self, x):
             x = self.linear2(x)
             return x
 
-    import modelopt.torch.quantization as mtq
-    from modelopt.torch.quantization.utils import export_torch_mode
-
     def calibrate_loop(model):
         """Simple calibration function for testing."""
         model(input_tensor)
@@ -289,8 +286,6 @@ def calibrate_loop(model):
 
     with torch.no_grad():
         with export_torch_mode():
-            from torch.export._trace import _export
-
             exp_program = _export(model, (input_tensor,))
             trt_model = torchtrt.dynamo.compile(
                 exp_program,
@@ -302,4 +297,4 @@ def calibrate_loop(model):
                 reuse_cached_engines=False,
             )
             outputs_trt = trt_model(input_tensor)
-            assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)
+            assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)
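For reference, the pattern these hunks converge on is the ModelOpt quantize-then-export flow: quantize the module with a calibration loop, export it under export_torch_mode() via the private _export entry point, and hand the exported program to torchtrt.dynamo.compile. The sketch below reconstructs that flow end to end; the quantization config (mtq.FP8_DEFAULT_CFG), the SimpleNetwork layer sizes, and the compile options other than inputs and reuse_cached_engines are illustrative assumptions rather than values taken from the hunks above.

# Minimal sketch of the FP8 quantize -> export -> compile flow (assumed details noted in comments).
import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt as torchtrt
from modelopt.torch.quantization.utils import export_torch_mode
from torch.export._trace import _export


class SimpleNetwork(torch.nn.Module):
    # Layer sizes are placeholders; the test's own module body is not shown in the hunks.
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 5)
        self.linear2 = torch.nn.Linear(5, 1)

    def forward(self, x):
        x = self.linear1(x)
        x = torch.nn.functional.relu(x)
        x = self.linear2(x)
        return x


input_tensor = torch.randn(1, 10).cuda()
model = SimpleNetwork().eval().cuda()


def calibrate_loop(model):
    """Run representative data through the model so ModelOpt can collect ranges."""
    model(input_tensor)


# Insert FP8 quant/dequant nodes with ModelOpt; FP8_DEFAULT_CFG is an assumed config choice.
mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=calibrate_loop)
output_pyt = model(input_tensor)

with torch.no_grad():
    # export_torch_mode() is ModelOpt's export context; _export is the private
    # torch.export entry point the diff switches the tests to.
    with export_torch_mode():
        exp_program = _export(model, (input_tensor,))
        trt_model = torchtrt.dynamo.compile(
            exp_program,
            inputs=[input_tensor],
            enabled_precisions={torch.float8_e4m3fn},  # assumed precision setting
            min_block_size=1,  # assumed, to force TRT conversion of the small graph
            cache_built_engines=False,
            reuse_cached_engines=False,
        )
        outputs_trt = trt_model(input_tensor)
        # Looser rtol (5e-3) matches the tolerance the diff adopts.
        assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)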