@@ -199,5 +199,62 @@ def forward(self, x):
199
199
)
200
200
201
201
202
class TestTRTModuleFloat64Input(TestCase):
    """Tests that a TRTModule built from float64 (double) example inputs
    survives serialization round-trips: whole-module torch.save/torch.load
    and state_dict save/restore.

    NOTE(review): these tests require a CUDA device and a working TensorRT
    install — they exercise `TRTModule`, `TRTInterpreter`, and the acc
    tracer from the enclosing project.
    """

    def test_save_and_load_trt_module(self):
        """Round-trip a TRTModule through torch.save / torch.load and check
        the reloaded engine matches the eager-mode reference output."""

        class TestModule(torch.nn.Module):
            def forward(self, x):
                return x + x

        # double() inputs: the engine itself is lowered to FP32, which is
        # why the comparison below uses check_dtype=False.
        inputs = [torch.randn(5, 5).double()]
        mod = TestModule().eval()
        ref_output = mod(*inputs)

        mod = acc_tracer.trace(mod, inputs)
        interp = TRTInterpreter(
            mod,
            input_specs=InputTensorSpec.from_tensors(inputs),
        )
        trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32))
        try:
            torch.save(trt_mod, "trt.pt")
            reload_trt_mod = torch.load("trt.pt")

            torch.testing.assert_close(
                reload_trt_mod(inputs[0].cuda()).cpu(),
                ref_output,
                rtol=1e-04,
                atol=1e-04,
                check_dtype=False,
            )
        finally:
            # Clean up the serialized artifact even when the assertion
            # above fails, so a failing run does not pollute the CWD.
            # Use the same relative path that torch.save() used rather
            # than rebuilding it from os.getcwd().
            if os.path.exists("trt.pt"):
                os.remove("trt.pt")

    def test_save_and_load_state_dict(self):
        """Transfer a TRTModule's state_dict into a freshly constructed
        TRTModule and check the restored engine matches the eager-mode
        reference output."""

        class TestModule(torch.nn.Module):
            def forward(self, x):
                return x + x

        # double() inputs; engine is FP32, hence check_dtype=False below.
        inputs = [torch.randn(5, 5).double()]
        mod = TestModule().eval()
        ref_output = mod(*inputs)

        mod = acc_tracer.trace(mod, inputs)
        interp = TRTInterpreter(
            mod,
            input_specs=InputTensorSpec.from_tensors(inputs),
        )
        trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32))
        st = trt_mod.state_dict()

        # An empty TRTModule must be able to rehydrate fully from the
        # state_dict alone (engine bytes, I/O binding names, etc.).
        new_trt_mod = TRTModule()
        new_trt_mod.load_state_dict(st)

        torch.testing.assert_close(
            new_trt_mod(inputs[0].cuda()).cpu(),
            ref_output,
            rtol=1e-04,
            atol=1e-04,
            check_dtype=False,
        )
# Entry point: run this file's test suite when executed directly.
if __name__ == "__main__":
    run_tests()
0 commit comments