|
| 1 | +#include "module_test.h" |
| 2 | + |
| 3 | +TEST_P(ModuleTests, FP16CompiledModuleIsClose) { |
| 4 | + std::vector<torch::jit::IValue> jit_inputs_ivalues; |
| 5 | + std::vector<torch::jit::IValue> trt_inputs_ivalues; |
| 6 | + for (auto in_shape : input_shapes) { |
| 7 | + auto in = at::randint(5, in_shape, {at::kCUDA}); |
| 8 | + in = in.to(torch::kF16); |
| 9 | + jit_inputs_ivalues.push_back(in.clone()); |
| 10 | + trt_inputs_ivalues.push_back(in.clone()); |
| 11 | + } |
| 12 | + |
| 13 | + auto extra_info = trtorch::ExtraInfo({input_shapes}); |
| 14 | + extra_info.op_precision = torch::kF16; |
| 15 | + extra_info.strict_types = true; |
| 16 | + |
| 17 | + auto trt_mod = trtorch::CompileGraph(mod, extra_info); |
| 18 | + torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues); |
| 19 | + std::vector<at::Tensor> trt_results; |
| 20 | + trt_results.push_back(trt_results_ivalues.toTensor()); |
| 21 | + |
| 22 | + mod.to(torch::kF16); |
| 23 | + torch::jit::IValue jit_results_ivalues = trtorch::tests::util::RunModuleForward(mod, jit_inputs_ivalues); |
| 24 | + std::vector<at::Tensor> jit_results; |
| 25 | + jit_results.push_back(jit_results_ivalues.toTensor()); |
| 26 | + |
| 27 | + for (size_t i = 0; i < trt_results.size(); i++) { |
| 28 | + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]), 2e-5)); |
| 29 | + } |
| 30 | +} |
| 31 | + |
| 32 | + |
| 33 | +INSTANTIATE_TEST_SUITE_P(CompiledModuleForwardIsCloseSuite, |
| 34 | + ModuleTests, |
| 35 | + testing::Values( |
| 36 | + PathAndInSize({"tests/modules/resnet18.jit.pt", |
| 37 | + {{1,3,224,224}}}), |
| 38 | + PathAndInSize({"tests/modules/resnet50.jit.pt", |
| 39 | + {{1,3,224,224}}}), |
| 40 | + PathAndInSize({"tests/modules/mobilenet_v2.jit.pt", |
| 41 | + {{1,3,224,224}}}))); |
0 commit comments